agent_chain_core/runnables/
schema.rs1use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10
11#[derive(Debug, Clone, Default, Serialize, Deserialize)]
16pub struct EventData {
17 #[serde(skip_serializing_if = "Option::is_none")]
26 pub input: Option<Value>,
27
28 #[serde(skip_serializing_if = "Option::is_none")]
32 pub error: Option<String>,
33
34 #[serde(skip_serializing_if = "Option::is_none")]
42 pub output: Option<Value>,
43
44 #[serde(skip_serializing_if = "Option::is_none")]
49 pub chunk: Option<Value>,
50}
51
52impl EventData {
53 pub fn new() -> Self {
55 Self::default()
56 }
57
58 pub fn with_input(mut self, input: Value) -> Self {
60 self.input = Some(input);
61 self
62 }
63
64 pub fn with_error(mut self, error: impl Into<String>) -> Self {
66 self.error = Some(error.into());
67 self
68 }
69
70 pub fn with_output(mut self, output: Value) -> Self {
72 self.output = Some(output);
73 self
74 }
75
76 pub fn with_chunk(mut self, chunk: Value) -> Self {
78 self.chunk = Some(chunk);
79 self
80 }
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct BaseStreamEvent {
102 pub event: String,
104
105 pub run_id: String,
110
111 #[serde(default, skip_serializing_if = "Vec::is_empty")]
118 pub tags: Vec<String>,
119
120 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
127 pub metadata: HashMap<String, Value>,
128
129 #[serde(default)]
140 pub parent_ids: Vec<String>,
141}
142
143impl BaseStreamEvent {
144 pub fn new(event: impl Into<String>, run_id: impl Into<String>) -> Self {
146 Self {
147 event: event.into(),
148 run_id: run_id.into(),
149 tags: Vec::new(),
150 metadata: HashMap::new(),
151 parent_ids: Vec::new(),
152 }
153 }
154
155 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
157 self.tags = tags;
158 self
159 }
160
161 pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
163 self.metadata = metadata;
164 self
165 }
166
167 pub fn with_parent_ids(mut self, parent_ids: Vec<String>) -> Self {
169 self.parent_ids = parent_ids;
170 self
171 }
172}
173
174#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct StandardStreamEvent {
177 #[serde(flatten)]
179 pub base: BaseStreamEvent,
180
181 pub data: EventData,
185
186 pub name: String,
188}
189
190impl StandardStreamEvent {
191 pub fn new(
193 event: impl Into<String>,
194 run_id: impl Into<String>,
195 name: impl Into<String>,
196 ) -> Self {
197 Self {
198 base: BaseStreamEvent::new(event, run_id),
199 data: EventData::new(),
200 name: name.into(),
201 }
202 }
203
204 pub fn with_data(mut self, data: EventData) -> Self {
206 self.data = data;
207 self
208 }
209
210 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
212 self.base.tags = tags;
213 self
214 }
215
216 pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
218 self.base.metadata = metadata;
219 self
220 }
221
222 pub fn with_parent_ids(mut self, parent_ids: Vec<String>) -> Self {
224 self.base.parent_ids = parent_ids;
225 self
226 }
227}
228
/// The `event` value that [`CustomStreamEvent::new`] assigns to custom events.
pub const CUSTOM_EVENT_TYPE: &str = "on_custom_event";
231
232#[derive(Debug, Clone, Serialize, Deserialize)]
234pub struct CustomStreamEvent {
235 #[serde(flatten)]
237 pub base: BaseStreamEvent,
238
239 pub name: String,
241
242 pub data: Value,
244}
245
246impl CustomStreamEvent {
247 pub fn new(run_id: impl Into<String>, name: impl Into<String>, data: Value) -> Self {
251 Self {
252 base: BaseStreamEvent::new(CUSTOM_EVENT_TYPE, run_id),
253 name: name.into(),
254 data,
255 }
256 }
257
258 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
260 self.base.tags = tags;
261 self
262 }
263
264 pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
266 self.base.metadata = metadata;
267 self
268 }
269
270 pub fn with_parent_ids(mut self, parent_ids: Vec<String>) -> Self {
272 self.base.parent_ids = parent_ids;
273 self
274 }
275}
276
277#[derive(Debug, Clone, Serialize, Deserialize)]
282#[serde(untagged)]
283pub enum StreamEvent {
284 Standard(StandardStreamEvent),
286 Custom(CustomStreamEvent),
288}
289
290impl StreamEvent {
291 pub fn event(&self) -> &str {
293 match self {
294 StreamEvent::Standard(e) => &e.base.event,
295 StreamEvent::Custom(e) => &e.base.event,
296 }
297 }
298
299 pub fn run_id(&self) -> &str {
301 match self {
302 StreamEvent::Standard(e) => &e.base.run_id,
303 StreamEvent::Custom(e) => &e.base.run_id,
304 }
305 }
306
307 pub fn name(&self) -> &str {
309 match self {
310 StreamEvent::Standard(e) => &e.name,
311 StreamEvent::Custom(e) => &e.name,
312 }
313 }
314
315 pub fn tags(&self) -> &[String] {
317 match self {
318 StreamEvent::Standard(e) => &e.base.tags,
319 StreamEvent::Custom(e) => &e.base.tags,
320 }
321 }
322
323 pub fn metadata(&self) -> &HashMap<String, Value> {
325 match self {
326 StreamEvent::Standard(e) => &e.base.metadata,
327 StreamEvent::Custom(e) => &e.base.metadata,
328 }
329 }
330
331 pub fn parent_ids(&self) -> &[String] {
333 match self {
334 StreamEvent::Standard(e) => &e.base.parent_ids,
335 StreamEvent::Custom(e) => &e.base.parent_ids,
336 }
337 }
338
339 pub fn is_custom(&self) -> bool {
341 matches!(self, StreamEvent::Custom(_))
342 }
343
344 pub fn is_standard(&self) -> bool {
346 matches!(self, StreamEvent::Standard(_))
347 }
348}
349
350impl From<StandardStreamEvent> for StreamEvent {
351 fn from(event: StandardStreamEvent) -> Self {
352 StreamEvent::Standard(event)
353 }
354}
355
356impl From<CustomStreamEvent> for StreamEvent {
357 fn from(event: CustomStreamEvent) -> Self {
358 StreamEvent::Custom(event)
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_event_data() {
        let data = EventData::new()
            .with_input(serde_json::json!("hello"))
            .with_output(serde_json::json!("world"));

        assert_eq!(data.input, Some(serde_json::json!("hello")));
        assert_eq!(data.output, Some(serde_json::json!("world")));
        assert!(data.error.is_none());
        assert!(data.chunk.is_none());
    }

    #[test]
    fn test_event_data_error_and_chunk() {
        let data = EventData::new()
            .with_error("boom")
            .with_chunk(serde_json::json!({"delta": "x"}));

        assert_eq!(data.error.as_deref(), Some("boom"));
        assert_eq!(data.chunk, Some(serde_json::json!({"delta": "x"})));
        assert!(data.input.is_none());
        assert!(data.output.is_none());
    }

    #[test]
    fn test_standard_stream_event() {
        let event = StandardStreamEvent::new("on_chain_start", "run-123", "my_chain")
            .with_tags(vec!["tag1".to_string()])
            .with_data(EventData::new().with_input(serde_json::json!({"key": "value"})));

        assert_eq!(event.base.event, "on_chain_start");
        assert_eq!(event.base.run_id, "run-123");
        assert_eq!(event.name, "my_chain");
        assert_eq!(event.base.tags, vec!["tag1"]);
        assert!(event.data.input.is_some());
    }

    #[test]
    fn test_custom_stream_event() {
        let event = CustomStreamEvent::new(
            "run-456",
            "my_custom_event",
            serde_json::json!({
                "custom_field": "custom_value"
            }),
        );

        assert_eq!(event.base.event, CUSTOM_EVENT_TYPE);
        assert_eq!(event.base.run_id, "run-456");
        assert_eq!(event.name, "my_custom_event");
        assert_eq!(
            event.data,
            serde_json::json!({"custom_field": "custom_value"})
        );
    }

    #[test]
    fn test_stream_event_enum() {
        let standard =
            StreamEvent::Standard(StandardStreamEvent::new("on_chain_end", "run-1", "chain"));
        let custom = StreamEvent::Custom(CustomStreamEvent::new(
            "run-2",
            "custom",
            serde_json::json!(null),
        ));

        assert!(standard.is_standard());
        assert!(!standard.is_custom());
        assert_eq!(standard.event(), "on_chain_end");
        assert_eq!(standard.name(), "chain");

        assert!(custom.is_custom());
        assert!(!custom.is_standard());
        assert_eq!(custom.event(), CUSTOM_EVENT_TYPE);
        assert_eq!(custom.name(), "custom");
    }

    #[test]
    fn test_stream_event_shared_accessors() {
        let mut metadata = HashMap::new();
        metadata.insert("key".to_string(), serde_json::json!(1));

        let custom: StreamEvent =
            CustomStreamEvent::new("run-3", "custom", serde_json::json!(null))
                .with_tags(vec!["t1".to_string()])
                .with_metadata(metadata.clone())
                .with_parent_ids(vec!["parent".to_string()])
                .into();

        assert!(custom.is_custom());
        assert_eq!(custom.run_id(), "run-3");
        assert_eq!(custom.tags(), vec!["t1"]);
        assert_eq!(custom.metadata(), &metadata);
        assert_eq!(custom.parent_ids(), vec!["parent"]);

        let standard: StreamEvent = StandardStreamEvent::new("on_chain_start", "run-4", "chain")
            .with_parent_ids(vec!["root".to_string()])
            .into();

        assert!(standard.is_standard());
        assert!(standard.tags().is_empty());
        assert!(standard.metadata().is_empty());
        assert_eq!(standard.parent_ids(), vec!["root"]);
    }

    #[test]
    fn test_stream_event_serialization() {
        let event = StandardStreamEvent::new("on_chain_start", "run-123", "test_chain")
            .with_data(EventData::new().with_input(serde_json::json!("input")));

        let encoded = serde_json::to_value(&event).unwrap();

        // `base` is flattened, so its fields sit at the top level.
        assert_eq!(encoded["event"], "on_chain_start");
        assert_eq!(encoded["run_id"], "run-123");
        assert_eq!(encoded["name"], "test_chain");
        // Unset `EventData` fields are skipped rather than written as null.
        assert_eq!(encoded["data"], serde_json::json!({"input": "input"}));
        // Empty tags/metadata are skipped; parent_ids is always emitted.
        assert!(encoded.get("tags").is_none());
        assert!(encoded.get("metadata").is_none());
        assert_eq!(encoded["parent_ids"], serde_json::json!([]));

        let deserialized: StandardStreamEvent = serde_json::from_value(encoded).unwrap();
        assert_eq!(deserialized.base.event, "on_chain_start");
        assert_eq!(deserialized.base.run_id, "run-123");
        assert_eq!(deserialized.name, "test_chain");
        assert_eq!(deserialized.data.input, Some(serde_json::json!("input")));
        assert!(deserialized.base.tags.is_empty());
        assert!(deserialized.base.metadata.is_empty());
    }

    #[test]
    fn test_custom_stream_event_serialization() {
        let event = CustomStreamEvent::new("run-9", "progress", serde_json::json!({"pct": 50}));

        let encoded = serde_json::to_value(&event).unwrap();
        assert_eq!(encoded["event"], CUSTOM_EVENT_TYPE);
        assert_eq!(encoded["name"], "progress");
        assert_eq!(encoded["data"], serde_json::json!({"pct": 50}));

        let decoded: CustomStreamEvent = serde_json::from_value(encoded).unwrap();
        assert_eq!(decoded.base.event, CUSTOM_EVENT_TYPE);
        assert_eq!(decoded.name, "progress");
        assert_eq!(decoded.data, serde_json::json!({"pct": 50}));
    }
}