// kapsl_scheduler/lib.rs

1pub mod cron_scheduler;
2pub mod gpu_executor;
3pub mod mesh_routing;
4pub mod priority;
5pub mod replica_pool;
6pub mod request;
7pub mod request_metadata;
8pub mod scheduler;
9
10// Re-export main types
11pub use cron_scheduler::{CronCallback, CronError, CronJob, CronJobInfo, CronSchedule, CronScheduler};
12pub use priority::Priority;
13pub use replica_pool::{PoolStrategy, ReplicaPool, ReplicaScheduler, ReplicaStats};
14pub use request_metadata::{determine_priority, RequestMetadata};
15pub use scheduler::{QueueOverflowPolicy, Scheduler};
16
#[cfg(test)]
mod tests {
    use crate::priority::Priority;
    use crate::scheduler::Scheduler;
    use async_trait::async_trait;
    use kapsl_engine_api::{
        BinaryTensorPacket, Engine, EngineError, InferenceRequest, TensorDtype,
    };
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Test double for [`Engine`] whose `infer` echoes the request's primary input back.
    ///
    /// When `delay` is set, each `infer` call blocks the calling thread for that long before
    /// responding. This lets tests observe a request while it is still in flight.
    struct MockEngine {
        /// Number of `infer` calls that have completed. No test currently asserts on it.
        call_count: Arc<AtomicUsize>,
        /// Optional artificial latency applied inside `infer`.
        delay: Option<Duration>,
    }

    impl MockEngine {
        /// Mock engine that responds immediately.
        fn new() -> Self {
            Self {
                call_count: Arc::new(AtomicUsize::new(0)),
                delay: None,
            }
        }

        /// Mock engine whose `infer` sleeps for `ms` milliseconds before responding.
        fn with_delay(ms: u64) -> Self {
            Self {
                call_count: Arc::new(AtomicUsize::new(0)),
                delay: Some(Duration::from_millis(ms)),
            }
        }
    }

    #[async_trait]
    impl Engine for MockEngine {
        async fn load(&mut self, _model_path: &std::path::Path) -> Result<(), EngineError> {
            Ok(())
        }

        fn infer(&self, request: &InferenceRequest) -> Result<BinaryTensorPacket, EngineError> {
            if let Some(delay) = self.delay {
                // `infer` is a synchronous method, so a blocking sleep is used here to
                // model a slow engine that occupies its caller for the whole call.
                std::thread::sleep(delay);
            }
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(request.input.clone())
        }

        /// Single-item stream that echoes the input.
        ///
        /// Unlike `infer`, this ignores `delay` and does not bump `call_count`.
        fn infer_stream(
            &self,
            request: &InferenceRequest,
        ) -> std::pin::Pin<
            Box<dyn futures::stream::Stream<Item = Result<BinaryTensorPacket, EngineError>> + Send>,
        > {
            let result = Ok(request.input.clone());
            Box::pin(futures::stream::once(async move { result }))
        }

        fn unload(&mut self) {}

        fn metrics(&self) -> kapsl_engine_api::EngineMetrics {
            kapsl_engine_api::EngineMetrics::default()
        }

        fn health_check(&self) -> Result<(), EngineError> {
            Ok(()) // Mock is always healthy
        }
    }

    /// Minimal request: a 1x1 `Float32` tensor (4 zero bytes) with no optional fields set.
    fn make_request() -> InferenceRequest {
        InferenceRequest {
            input: BinaryTensorPacket {
                shape: vec![1, 1],
                dtype: TensorDtype::Float32,
                data: vec![0, 0, 0, 0],
            },
            additional_inputs: Vec::new(),
            session_id: None,
            metadata: None,
            cancellation: None,
        }
    }

    /// Builds a [`Scheduler`] over a single mock engine, using the configuration shared by
    /// every test in this module.
    ///
    /// NOTE(review): `Scheduler::new` defines what each positional argument means, and it is
    /// not visible here. Confirm the meaning before changing any value.
    fn make_scheduler(engine: MockEngine) -> Scheduler {
        let engine: Arc<dyn Engine> = Arc::new(engine);
        Scheduler::new(vec![engine], 2, 1, 1000, true, 1, 0, None)
    }

    #[tokio::test]
    async fn test_cpu_scheduling() {
        let scheduler = make_scheduler(MockEngine::new());

        let result = scheduler
            .infer(make_request(), Priority::Throughput, true)
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_gpu_scheduling() {
        let scheduler = make_scheduler(MockEngine::new());

        let result = scheduler
            .infer(make_request(), Priority::LatencyCritical, false)
            .await;
        assert!(result.is_ok());
    }

    /// NOTE(review): judging by the name, this presumably exercises a fallback path, using a
    /// throughput request with the boolean flag off against a slow (50 ms) engine. Confirm
    /// against `Scheduler::infer`.
    #[tokio::test]
    async fn test_fallback() {
        let scheduler = make_scheduler(MockEngine::with_delay(50));

        let result = scheduler
            .infer(make_request(), Priority::Throughput, false)
            .await;
        assert!(result.is_ok());
    }

    /// Checks that the CPU queue depth counts an in-flight request and drains afterwards.
    ///
    /// NOTE(review): this is timing-based. The 10 ms head start must elapse while the
    /// 100 ms engine call is still running, so the test may be flaky on a heavily loaded
    /// machine.
    #[tokio::test]
    async fn test_cpu_queue_depth_tracking() {
        // The engine delay keeps the request in flight long enough to observe it.
        let scheduler = Arc::new(make_scheduler(MockEngine::with_delay(100)));

        let (cpu, _gpu) = scheduler.get_queue_depth();
        assert_eq!(cpu, 0, "queue should start empty");

        let background = Arc::clone(&scheduler);
        let handle = tokio::spawn(async move {
            background
                .infer(make_request(), Priority::Throughput, true)
                .await
        });

        // Give the spawned task a moment to submit its request.
        tokio::time::sleep(Duration::from_millis(10)).await;

        let (cpu, _gpu) = scheduler.get_queue_depth();
        assert_eq!(cpu, 1, "in-flight request should be counted");

        // Also require the inference itself to succeed, so a failing request cannot
        // slip through behind passing depth checks.
        let result = handle.await.expect("background inference task panicked");
        assert!(result.is_ok(), "background inference should succeed");

        let (cpu, _gpu) = scheduler.get_queue_depth();
        assert_eq!(cpu, 0, "queue should drain after completion");
    }
}