1use std::sync::Arc;
4use std::time::Instant;
5
6use async_trait::async_trait;
7use futures::stream::{self, StreamExt};
8use uuid::Uuid;
9
10use crate::extensions::Extensions;
11use crate::stream::Observer;
12
/// Per-invocation configuration passed by value to every [`Runnable`] method.
///
/// Construct it with [`RunnableConfig::default`] / [`RunnableConfig::new`] and the `with_*`
/// builders. For nested calls, derive a child with [`RunnableConfig::clone_for_subcall`].
///
/// Note: the derived `clone()` keeps the same `run_id` (it is a plain field copy). The
/// tests in this file also expect a cloned config to have empty `extras`.
#[derive(Clone)]
pub struct RunnableConfig {
    /// Recursion limit; `25` by default and copied unchanged into sub-call configs.
    /// NOTE(review): nothing in this file enforces it — presumably checked by the caller/executor; confirm.
    pub recursion_limit: u32,
    /// Upper bound on concurrently running `invoke` calls in the default `Runnable::batch`
    /// (which treats `0` as `1`).
    pub max_concurrency: usize,
    /// Free-form tags; copied into sub-call configs.
    pub tags: Vec<String>,
    /// Arbitrary JSON metadata; `serde_json::Value::Null` by default, copied into sub-call configs.
    pub metadata: serde_json::Value,
    /// Observers that `emit` forwards events to, in insertion order. Sub-call configs share
    /// the same `Arc`s.
    pub observers: Vec<Arc<dyn Observer>>,
    /// Identifier of this run; a random v4 UUID by default, regenerated for every sub-call.
    pub run_id: Uuid,
    /// Optional cancellation token; `is_cancelled` reports its state. Sub-call configs
    /// receive a clone of it.
    pub cancel_token: Option<tokio_util::sync::CancellationToken>,
    /// Optional deadline; copied into sub-call configs.
    /// NOTE(review): not consulted by `is_cancelled` or anything else in this file.
    pub deadline: Option<Instant>,
    /// Extension storage (the tests use it as a type-keyed map via `insert` / `contains::<T>()`).
    /// Not inherited by `clone_for_subcall`, which starts the child with `Extensions::new()`.
    pub extras: Extensions,
    /// `run_id` of the config this one was derived from by `clone_for_subcall`;
    /// `None` for a root config.
    pub parent_run_id: Option<Uuid>,
}
41
42impl Default for RunnableConfig {
43 fn default() -> Self {
44 Self {
45 recursion_limit: 25,
46 max_concurrency: num_cpus::get().max(1),
47 tags: Vec::new(),
48 metadata: serde_json::Value::Null,
49 observers: Vec::new(),
50 run_id: Uuid::new_v4(),
51 cancel_token: None,
52 deadline: None,
53 extras: Extensions::new(),
54 parent_run_id: None,
55 }
56 }
57}
58
59impl RunnableConfig {
60 pub fn new() -> Self {
62 Self::default()
63 }
64
65 pub fn with_recursion_limit(mut self, n: u32) -> Self {
67 self.recursion_limit = n;
68 self
69 }
70
71 pub fn with_max_concurrency(mut self, n: usize) -> Self {
73 self.max_concurrency = n;
74 self
75 }
76
77 pub fn with_observer(mut self, o: Arc<dyn Observer>) -> Self {
79 self.observers.push(o);
80 self
81 }
82
83 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
85 self.tags.push(tag.into());
86 self
87 }
88
89 pub fn with_cancel_token(mut self, t: tokio_util::sync::CancellationToken) -> Self {
91 self.cancel_token = Some(t);
92 self
93 }
94
95 pub fn with_parent_run_id(mut self, id: Uuid) -> Self {
98 self.parent_run_id = Some(id);
99 self
100 }
101
102 pub fn emit(&self, event: &crate::stream::Event) {
104 for o in &self.observers {
105 o.on_event(event);
106 }
107 }
108
109 pub fn is_cancelled(&self) -> bool {
111 self.cancel_token
112 .as_ref()
113 .map(|t| t.is_cancelled())
114 .unwrap_or(false)
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` yields the documented limit, a non-zero concurrency and no observers.
    #[test]
    fn defaults_sane() {
        let cfg = RunnableConfig::default();
        assert_eq!(cfg.recursion_limit, 25);
        assert!(cfg.max_concurrency >= 1);
        assert!(cfg.observers.is_empty());
    }

    /// Builder calls compose and each one sets its own field.
    #[test]
    fn builder_chains() {
        let cfg = RunnableConfig::new()
            .with_recursion_limit(10)
            .with_max_concurrency(4)
            .with_tag("prod");

        assert_eq!(cfg.recursion_limit, 10);
        assert_eq!(cfg.max_concurrency, 4);
        assert_eq!(cfg.tags, vec!["prod"]);
    }

    /// With no token installed, a config never reports cancellation.
    #[test]
    fn cancel_default_false() {
        let cfg = RunnableConfig::default();
        assert!(!cfg.is_cancelled());
    }

    /// Cloning keeps the plain settings but does not carry `extras` over.
    #[test]
    fn config_clones_with_extras_emptied() {
        let mut original = RunnableConfig::default()
            .with_recursion_limit(50)
            .with_max_concurrency(8)
            .with_tag("test");
        original.extras.insert(42u32);
        assert!(original.extras.contains::<u32>());

        let copy = original.clone();

        assert_eq!(copy.recursion_limit, 50);
        assert_eq!(copy.max_concurrency, 8);
        assert_eq!(copy.tags, vec!["test"]);
        assert!(copy.extras.is_empty());
    }

    /// A root config has no parent run.
    #[test]
    fn parent_run_id_default_is_none() {
        let cfg = RunnableConfig::default();
        assert!(cfg.parent_run_id.is_none());
    }

    /// A derived child points back at its parent and gets its own run id.
    #[test]
    fn clone_for_subcall_sets_parent_run_id_to_self() {
        let root = Arc::new(RunnableConfig::default());
        let derived = RunnableConfig::clone_for_subcall(&root);

        assert_eq!(derived.parent_run_id, Some(root.run_id));
        assert_ne!(derived.run_id, root.run_id);
    }

    /// The explicit builder records the given parent id.
    #[test]
    fn with_parent_run_id_builder() {
        let parent = Uuid::new_v4();
        let cfg = RunnableConfig::default().with_parent_run_id(parent);
        assert_eq!(cfg.parent_run_id, Some(parent));
    }
}
185
186#[async_trait]
192pub trait Runnable<I, O>: Send + Sync
193where
194 I: Send + 'static,
195 O: Send + 'static,
196{
197 async fn invoke(&self, input: I, config: RunnableConfig) -> crate::Result<O>;
199
200 async fn batch(&self, inputs: Vec<I>, config: RunnableConfig) -> crate::Result<Vec<O>>
203 where
204 I: 'static,
205 O: 'static,
206 Self: Sized + Sync,
207 {
208 let concurrency = config.max_concurrency.max(1);
209 let cfg = Arc::new(config);
210 stream::iter(inputs)
211 .map(|input| {
212 let cfg = cfg.clone();
213 async move {
214 self.invoke(input, RunnableConfig::clone_for_subcall(&cfg))
215 .await
216 }
217 })
218 .buffer_unordered(concurrency)
219 .collect::<Vec<_>>()
220 .await
221 .into_iter()
222 .collect()
223 }
224
225 async fn stream(&self, input: I, config: RunnableConfig) -> crate::Result<RunnableStream<O>>
228 where
229 Self: Sized + Sync,
230 {
231 let result = self.invoke(input, config).await;
232 Ok(RunnableStream::once(result))
233 }
234
235 async fn stream_events(&self, input: I, config: RunnableConfig) -> crate::Result<EventStream>
238 where
239 I: serde::Serialize,
240 O: serde::Serialize,
241 Self: Sized + Sync,
242 {
243 let runnable = self.name().to_string();
244 let run_id = config.run_id;
245 let input_json = serde_json::to_value(&input).unwrap_or(serde_json::Value::Null);
246
247 let on_start = Event::OnStart {
248 runnable: runnable.clone(),
249 run_id,
250 input: input_json,
251 };
252 let result = self.invoke(input, config).await;
253 let on_end_or_err = match &result {
254 Ok(o) => Event::OnEnd {
255 runnable,
256 run_id,
257 output: serde_json::to_value(o).unwrap_or(serde_json::Value::Null),
258 },
259 Err(e) => Event::OnError {
260 error: e.to_string(),
261 run_id,
262 },
263 };
264
265 Ok(EventStream::new(stream::iter(vec![
266 on_start,
267 on_end_or_err,
268 ])))
269 }
270
271 fn name(&self) -> &str {
273 std::any::type_name::<Self>()
274 }
275
276 fn input_schema(&self) -> Option<serde_json::Value> {
278 None
279 }
280
281 fn output_schema(&self) -> Option<serde_json::Value> {
283 None
284 }
285}
286
287use crate::stream::{Event, EventStream, RunnableStream};
288
289impl RunnableConfig {
290 pub fn clone_for_subcall(parent: &Arc<RunnableConfig>) -> RunnableConfig {
294 RunnableConfig {
295 recursion_limit: parent.recursion_limit,
296 max_concurrency: parent.max_concurrency,
297 tags: parent.tags.clone(),
298 metadata: parent.metadata.clone(),
299 observers: parent.observers.clone(),
300 run_id: Uuid::new_v4(),
301 parent_run_id: Some(parent.run_id),
302 cancel_token: parent.cancel_token.clone(),
303 deadline: parent.deadline,
304 extras: Extensions::new(),
305 }
306 }
307}
308
#[cfg(test)]
mod runnable_tests {
    use super::*;
    use async_trait::async_trait;

    /// Minimal runnable used by these tests: returns twice its input.
    struct Doubler;

    #[async_trait]
    impl Runnable<u32, u32> for Doubler {
        async fn invoke(&self, input: u32, _: RunnableConfig) -> crate::Result<u32> {
            Ok(2 * input)
        }
    }

    /// A direct `invoke` returns the doubled value.
    #[tokio::test]
    async fn invoke_works() {
        let runnable = Doubler;
        let doubled = runnable.invoke(5, RunnableConfig::default()).await.unwrap();
        assert_eq!(doubled, 10);
    }

    /// The default `batch` processes every input.
    #[tokio::test]
    async fn default_batch_runs_each() {
        let runnable = Doubler;
        let mut results = runnable
            .batch(vec![1, 2, 3, 4], RunnableConfig::default())
            .await
            .unwrap();
        results.sort_unstable();
        assert_eq!(results, vec![2, 4, 6, 8]);
    }

    /// The default `stream` yields exactly the single invocation result.
    #[tokio::test]
    async fn default_stream_emits_one_item() {
        let runnable = Doubler;
        let output = runnable.stream(7, RunnableConfig::default()).await.unwrap();
        let items = output.collect_into_vec().await.unwrap();
        assert_eq!(items, vec![14]);
    }

    /// The default `stream_events` emits a start event followed by an end event.
    #[tokio::test]
    async fn default_stream_events_emits_start_end() {
        use futures::StreamExt;

        let runnable = Doubler;
        let mut feed = runnable
            .stream_events(3, RunnableConfig::default())
            .await
            .unwrap();

        let mut seen = Vec::new();
        while let Some(event) = feed.next().await {
            seen.push(event);
        }

        assert_eq!(seen.len(), 2);
        assert!(matches!(seen[0], Event::OnStart { .. }));
        assert!(matches!(seen[1], Event::OnEnd { .. }));
    }

    /// A concurrency limit of one still produces every result.
    #[tokio::test]
    async fn batch_respects_max_concurrency() {
        let runnable = Doubler;
        let serial = RunnableConfig::default().with_max_concurrency(1);
        let mut results = runnable.batch(vec![1, 2, 3], serial).await.unwrap();
        results.sort_unstable();
        assert_eq!(results, vec![2, 4, 6]);
    }
}