// Source: camel_function/provider/mod.rs

1use crate::pool::{RunnerHandle, RunnerPoolKey};
2use camel_api::{Exchange, function::*};
3use std::time::Duration;
4
/// Private module holding the `Sealed` marker trait.
///
/// `sealed` itself is not `pub`, so only this module and its submodules can
/// name `Sealed` — and therefore only they can satisfy the `sealed::Sealed`
/// supertrait bound on `FunctionProvider`.
mod sealed {
    /// Marker supertrait with no methods; exists purely to restrict implementors.
    pub trait Sealed {}
}
8
// NOTE(review): the previous doc claimed this type is named
// `FunctionHealthStatus` "to avoid collision with
// `camel_api::FunctionHealthStatus`", which is self-contradictory. If the
// glob-imported `camel_api::function` module also exports an item with this
// name, this local definition shadows it inside this module — confirm which
// type callers are meant to see.
/// Health status a function provider reports for one of its runners
/// (the success value of `FunctionProvider::health`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FunctionHealthStatus {
    /// The runner reported no problem.
    Healthy,
    /// The runner reported a problem; the string describes it.
    Unhealthy(String),
}
16
/// Errors returned by `FunctionProvider` operations.
///
/// The `String` payloads are free-form detail messages that are appended to
/// each variant's `Display` text.
#[derive(Debug, thiserror::Error)]
pub enum ProviderError {
    /// Spawning a runner failed.
    #[error("spawn failed: {0}")]
    SpawnFailed(String),
    /// A health check could not be carried out.
    /// NOTE(review): not produced in this file — an unhealthy runner is reported
    /// as `Ok(FunctionHealthStatus::Unhealthy(..))` by the fake instead.
    #[error("health check failed: {0}")]
    HealthFailed(String),
    /// Registering a function on a runner failed.
    #[error("register failed: {0}")]
    RegisterFailed(String),
    /// Unregistering a function from a runner failed.
    #[error("unregister failed: {0}")]
    UnregisterFailed(String),
    /// Invoking a function failed (the fake uses this when the function is not
    /// registered on the handle).
    #[error("invoke failed: {0}")]
    InvokeFailed(String),
    /// Shutting a runner down failed.
    #[error("shutdown failed: {0}")]
    ShutdownFailed(String),
    /// Presumably: a runner did not finish booting in time. Not produced in this
    /// file — confirm against the other provider implementations.
    #[error("boot timeout")]
    BootTimeout,
}
34
35#[async_trait::async_trait]
36pub(crate) trait FunctionProvider: Send + Sync + sealed::Sealed {
37    async fn spawn(&self, key: &RunnerPoolKey) -> Result<RunnerHandle, ProviderError>;
38    async fn shutdown(&self, handle: RunnerHandle) -> Result<(), ProviderError>;
39    async fn health(&self, handle: &RunnerHandle) -> Result<FunctionHealthStatus, ProviderError>;
40    async fn register(
41        &self,
42        handle: &RunnerHandle,
43        def: &FunctionDefinition,
44    ) -> Result<(), ProviderError>;
45    async fn unregister(&self, handle: &RunnerHandle, id: &FunctionId)
46    -> Result<(), ProviderError>;
47    async fn invoke(
48        &self,
49        handle: &RunnerHandle,
50        id: &FunctionId,
51        ex: &Exchange,
52        timeout: Duration,
53    ) -> Result<ExchangePatch, ProviderError>;
54}
55
// NOTE(review): file-backed submodule (`container.rs` or `container/mod.rs`);
// presumably another `FunctionProvider` implementation — confirm.
pub mod container;
57pub mod fake {
58    use super::*;
59    use std::collections::{HashMap, HashSet};
60    use std::sync::atomic::{AtomicUsize, Ordering};
61    use std::sync::{Arc, Mutex};
62    use tokio_util::sync::CancellationToken;
63
64    #[derive(Debug, Clone, Default)]
65    pub struct FakeProviderConfig {
66        pub fail_on_spawn: bool,
67        pub fail_on_register: usize,
68        pub fail_on_health: bool,
69        pub invoke_response: Option<ExchangePatch>,
70        pub invoke_delay: Option<std::time::Duration>,
71    }
72
73    #[derive(Debug, Clone)]
74    pub enum FakeCall {
75        Spawn(RunnerPoolKey),
76        Shutdown(RunnerPoolKey),
77        Health(String),
78        Register(String, FunctionId),
79        Unregister(String, FunctionId),
80        Invoke(String, FunctionId),
81    }
82
83    pub struct FakeProvider {
84        pub config: Arc<Mutex<FakeProviderConfig>>,
85        pub calls: Arc<Mutex<Vec<FakeCall>>>,
86        pub registered: Arc<Mutex<HashMap<String, HashSet<FunctionId>>>>,
87        pub spawned: Arc<Mutex<Vec<RunnerPoolKey>>>,
88        pub shutdowns: Arc<Mutex<Vec<RunnerPoolKey>>>,
89        register_ok_count: Arc<Mutex<usize>>,
90        spawn_count: AtomicUsize,
91    }
92
93    impl FakeProvider {
94        pub fn new(config: FakeProviderConfig) -> Self {
95            Self {
96                config: Arc::new(Mutex::new(config)),
97                calls: Arc::new(Mutex::new(Vec::new())),
98                registered: Arc::new(Mutex::new(HashMap::new())),
99                spawned: Arc::new(Mutex::new(Vec::new())),
100                shutdowns: Arc::new(Mutex::new(Vec::new())),
101                register_ok_count: Arc::new(Mutex::new(0)),
102                spawn_count: AtomicUsize::new(0),
103            }
104        }
105
106        pub fn spawn_count(&self) -> usize {
107            self.spawn_count.load(Ordering::SeqCst)
108        }
109    }
110
111    impl super::sealed::Sealed for FakeProvider {}
112
113    #[async_trait::async_trait]
114    impl FunctionProvider for FakeProvider {
115        async fn spawn(&self, key: &RunnerPoolKey) -> Result<RunnerHandle, ProviderError> {
116            self.spawn_count.fetch_add(1, Ordering::SeqCst);
117            self.calls
118                .lock()
119                .expect("calls") // allow-unwrap
120                .push(FakeCall::Spawn(key.clone()));
121            self.spawned.lock().expect("spawned").push(key.clone()); // allow-unwrap
122            if self.config.lock().expect("config").fail_on_spawn {
123                // allow-unwrap
124                return Err(ProviderError::SpawnFailed("configured".into()));
125            }
126            Ok(RunnerHandle {
127                id: format!("fake-{}", key.runtime),
128                state: Arc::new(Mutex::new(crate::pool::RunnerState::Booting)),
129                cancel: CancellationToken::new(),
130            })
131        }
132
133        async fn shutdown(&self, handle: RunnerHandle) -> Result<(), ProviderError> {
134            self.calls
135                .lock()
136                .expect("calls") // allow-unwrap
137                .push(FakeCall::Shutdown(RunnerPoolKey {
138                    runtime: handle.id.replace("fake-", ""),
139                }));
140            self.shutdowns
141                .lock()
142                .expect("shutdowns") // allow-unwrap
143                .push(RunnerPoolKey {
144                    runtime: handle.id.replace("fake-", ""),
145                });
146            Ok(())
147        }
148
149        async fn health(
150            &self,
151            handle: &RunnerHandle,
152        ) -> Result<FunctionHealthStatus, ProviderError> {
153            self.calls
154                .lock()
155                .expect("calls") // allow-unwrap
156                .push(FakeCall::Health(handle.id.clone()));
157            if self.config.lock().expect("config").fail_on_health {
158                // allow-unwrap
159                return Ok(FunctionHealthStatus::Unhealthy("configured".into()));
160            }
161            Ok(FunctionHealthStatus::Healthy)
162        }
163
164        async fn register(
165            &self,
166            handle: &RunnerHandle,
167            def: &FunctionDefinition,
168        ) -> Result<(), ProviderError> {
169            self.calls
170                .lock()
171                .expect("calls") // allow-unwrap
172                .push(FakeCall::Register(handle.id.clone(), def.id.clone()));
173            let mut count = self.register_ok_count.lock().expect("count"); // allow-unwrap
174            let cfg = self.config.lock().expect("config").clone(); // allow-unwrap
175            if cfg.fail_on_register > 0 && *count >= cfg.fail_on_register {
176                return Err(ProviderError::RegisterFailed("configured".into()));
177            }
178            *count += 1;
179            self.registered
180                .lock()
181                .expect("registered") // allow-unwrap
182                .entry(handle.id.clone())
183                .or_default()
184                .insert(def.id.clone());
185            Ok(())
186        }
187
188        async fn unregister(
189            &self,
190            handle: &RunnerHandle,
191            id: &FunctionId,
192        ) -> Result<(), ProviderError> {
193            self.calls
194                .lock()
195                .expect("calls") // allow-unwrap
196                .push(FakeCall::Unregister(handle.id.clone(), id.clone()));
197            if let Some(set) = self
198                .registered
199                .lock()
200                .expect("registered") // allow-unwrap
201                .get_mut(&handle.id)
202            {
203                set.remove(id);
204            }
205            Ok(())
206        }
207
208        async fn invoke(
209            &self,
210            handle: &RunnerHandle,
211            id: &FunctionId,
212            _ex: &Exchange,
213            _timeout: Duration,
214        ) -> Result<ExchangePatch, ProviderError> {
215            self.calls
216                .lock()
217                .expect("calls") // allow-unwrap
218                .push(FakeCall::Invoke(handle.id.clone(), id.clone()));
219            let exists = self
220                .registered
221                .lock()
222                .expect("registered") // allow-unwrap
223                .get(&handle.id)
224                .map(|s| s.contains(id))
225                .unwrap_or(false);
226            if !exists {
227                return Err(ProviderError::InvokeFailed("not registered".into()));
228            }
229            let cfg = self.config.lock().expect("config").clone(); // allow-unwrap
230            if let Some(delay) = cfg.invoke_delay {
231                tokio::time::sleep(delay).await;
232            }
233            Ok(cfg.invoke_response.unwrap_or_default())
234        }
235    }
236}