// camel_function/provider/mod.rs

1use crate::pool::{RunnerHandle, RunnerPoolKey};
2use camel_api::{Exchange, function::*};
3use std::time::Duration;
4
/// Private home of the `Sealed` marker that seals `FunctionProvider`.
///
/// This module is private to `provider`, so only `provider` and its child modules (for example
/// `fake`, which implements the marker for `FakeProvider`) can name `sealed::Sealed`. Because
/// `FunctionProvider` lists it as a supertrait, only those modules can implement the provider trait.
mod sealed {
    /// Method-less marker supertrait of `FunctionProvider`.
    pub trait Sealed {}
}
8
/// Result of probing a runner's health via `FunctionProvider::health`.
///
/// An unhealthy runner is reported as `Ok(HealthReport::Unhealthy(_))`. The `Err(ProviderError)`
/// arm of that method's return type is separate from this value.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare reports directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HealthReport {
    /// The runner is healthy.
    Healthy,
    /// The runner is unhealthy; the payload is a free-form reason.
    Unhealthy(String),
}
14
15#[derive(Debug, thiserror::Error)]
16pub enum ProviderError {
17    #[error("spawn failed: {0}")]
18    SpawnFailed(String),
19    #[error("health check failed: {0}")]
20    HealthFailed(String),
21    #[error("register failed: {0}")]
22    RegisterFailed(String),
23    #[error("unregister failed: {0}")]
24    UnregisterFailed(String),
25    #[error("invoke failed: {0}")]
26    InvokeFailed(String),
27    #[error("shutdown failed: {0}")]
28    ShutdownFailed(String),
29    #[error("boot timeout")]
30    BootTimeout,
31}
32
33#[async_trait::async_trait]
34pub(crate) trait FunctionProvider: Send + Sync + sealed::Sealed {
35    async fn spawn(&self, key: &RunnerPoolKey) -> Result<RunnerHandle, ProviderError>;
36    async fn shutdown(&self, handle: RunnerHandle) -> Result<(), ProviderError>;
37    async fn health(&self, handle: &RunnerHandle) -> Result<HealthReport, ProviderError>;
38    async fn register(
39        &self,
40        handle: &RunnerHandle,
41        def: &FunctionDefinition,
42    ) -> Result<(), ProviderError>;
43    async fn unregister(&self, handle: &RunnerHandle, id: &FunctionId)
44    -> Result<(), ProviderError>;
45    async fn invoke(
46        &self,
47        handle: &RunnerHandle,
48        id: &FunctionId,
49        ex: &Exchange,
50        timeout: Duration,
51    ) -> Result<ExchangePatch, ProviderError>;
52}
53
54pub mod container;
55pub mod fake {
56    use super::*;
57    use std::collections::{HashMap, HashSet};
58    use std::sync::atomic::{AtomicUsize, Ordering};
59    use std::sync::{Arc, Mutex};
60    use tokio_util::sync::CancellationToken;
61
62    #[derive(Debug, Clone, Default)]
63    pub struct FakeProviderConfig {
64        pub fail_on_spawn: bool,
65        pub fail_on_register: usize,
66        pub fail_on_health: bool,
67        pub invoke_response: Option<ExchangePatch>,
68        pub invoke_delay: Option<std::time::Duration>,
69    }
70
71    #[derive(Debug, Clone)]
72    pub enum FakeCall {
73        Spawn(RunnerPoolKey),
74        Shutdown(RunnerPoolKey),
75        Health(String),
76        Register(String, FunctionId),
77        Unregister(String, FunctionId),
78        Invoke(String, FunctionId),
79    }
80
81    pub struct FakeProvider {
82        pub config: Arc<Mutex<FakeProviderConfig>>,
83        pub calls: Arc<Mutex<Vec<FakeCall>>>,
84        pub registered: Arc<Mutex<HashMap<String, HashSet<FunctionId>>>>,
85        pub spawned: Arc<Mutex<Vec<RunnerPoolKey>>>,
86        pub shutdowns: Arc<Mutex<Vec<RunnerPoolKey>>>,
87        register_ok_count: Arc<Mutex<usize>>,
88        spawn_count: AtomicUsize,
89    }
90
91    impl FakeProvider {
92        pub fn new(config: FakeProviderConfig) -> Self {
93            Self {
94                config: Arc::new(Mutex::new(config)),
95                calls: Arc::new(Mutex::new(Vec::new())),
96                registered: Arc::new(Mutex::new(HashMap::new())),
97                spawned: Arc::new(Mutex::new(Vec::new())),
98                shutdowns: Arc::new(Mutex::new(Vec::new())),
99                register_ok_count: Arc::new(Mutex::new(0)),
100                spawn_count: AtomicUsize::new(0),
101            }
102        }
103
104        pub fn spawn_count(&self) -> usize {
105            self.spawn_count.load(Ordering::SeqCst)
106        }
107    }
108
109    impl super::sealed::Sealed for FakeProvider {}
110
111    #[async_trait::async_trait]
112    impl FunctionProvider for FakeProvider {
113        async fn spawn(&self, key: &RunnerPoolKey) -> Result<RunnerHandle, ProviderError> {
114            self.spawn_count.fetch_add(1, Ordering::SeqCst);
115            self.calls
116                .lock()
117                .expect("calls") // allow-unwrap
118                .push(FakeCall::Spawn(key.clone()));
119            self.spawned.lock().expect("spawned").push(key.clone()); // allow-unwrap
120            if self.config.lock().expect("config").fail_on_spawn {
121                // allow-unwrap
122                return Err(ProviderError::SpawnFailed("configured".into()));
123            }
124            Ok(RunnerHandle {
125                id: format!("fake-{}", key.runtime),
126                state: Arc::new(Mutex::new(crate::pool::RunnerState::Booting)),
127                cancel: CancellationToken::new(),
128            })
129        }
130
131        async fn shutdown(&self, handle: RunnerHandle) -> Result<(), ProviderError> {
132            self.calls
133                .lock()
134                .expect("calls") // allow-unwrap
135                .push(FakeCall::Shutdown(RunnerPoolKey {
136                    runtime: handle.id.replace("fake-", ""),
137                }));
138            self.shutdowns
139                .lock()
140                .expect("shutdowns") // allow-unwrap
141                .push(RunnerPoolKey {
142                    runtime: handle.id.replace("fake-", ""),
143                });
144            Ok(())
145        }
146
147        async fn health(&self, handle: &RunnerHandle) -> Result<HealthReport, ProviderError> {
148            self.calls
149                .lock()
150                .expect("calls") // allow-unwrap
151                .push(FakeCall::Health(handle.id.clone()));
152            if self.config.lock().expect("config").fail_on_health {
153                // allow-unwrap
154                return Ok(HealthReport::Unhealthy("configured".into()));
155            }
156            Ok(HealthReport::Healthy)
157        }
158
159        async fn register(
160            &self,
161            handle: &RunnerHandle,
162            def: &FunctionDefinition,
163        ) -> Result<(), ProviderError> {
164            self.calls
165                .lock()
166                .expect("calls") // allow-unwrap
167                .push(FakeCall::Register(handle.id.clone(), def.id.clone()));
168            let mut count = self.register_ok_count.lock().expect("count"); // allow-unwrap
169            let cfg = self.config.lock().expect("config").clone(); // allow-unwrap
170            if cfg.fail_on_register > 0 && *count >= cfg.fail_on_register {
171                return Err(ProviderError::RegisterFailed("configured".into()));
172            }
173            *count += 1;
174            self.registered
175                .lock()
176                .expect("registered") // allow-unwrap
177                .entry(handle.id.clone())
178                .or_default()
179                .insert(def.id.clone());
180            Ok(())
181        }
182
183        async fn unregister(
184            &self,
185            handle: &RunnerHandle,
186            id: &FunctionId,
187        ) -> Result<(), ProviderError> {
188            self.calls
189                .lock()
190                .expect("calls") // allow-unwrap
191                .push(FakeCall::Unregister(handle.id.clone(), id.clone()));
192            if let Some(set) = self
193                .registered
194                .lock()
195                .expect("registered") // allow-unwrap
196                .get_mut(&handle.id)
197            {
198                set.remove(id);
199            }
200            Ok(())
201        }
202
203        async fn invoke(
204            &self,
205            handle: &RunnerHandle,
206            id: &FunctionId,
207            _ex: &Exchange,
208            _timeout: Duration,
209        ) -> Result<ExchangePatch, ProviderError> {
210            self.calls
211                .lock()
212                .expect("calls") // allow-unwrap
213                .push(FakeCall::Invoke(handle.id.clone(), id.clone()));
214            let exists = self
215                .registered
216                .lock()
217                .expect("registered") // allow-unwrap
218                .get(&handle.id)
219                .map(|s| s.contains(id))
220                .unwrap_or(false);
221            if !exists {
222                return Err(ProviderError::InvokeFailed("not registered".into()));
223            }
224            let cfg = self.config.lock().expect("config").clone(); // allow-unwrap
225            if let Some(delay) = cfg.invoke_delay {
226                tokio::time::sleep(delay).await;
227            }
228            Ok(cfg.invoke_response.unwrap_or_default())
229        }
230    }
231}