agent_chain_core/runnables/
schema.rs

1//! Schema types for Runnables.
2//!
3//! This module contains typedefs that are used with `Runnable` objects,
4//! mirroring `langchain_core.runnables.schema`.
5
6use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10
11/// Data associated with a streaming event.
12///
13/// This struct contains optional fields that may be present depending
14/// on the event type (start, stream, end).
15#[derive(Debug, Clone, Default, Serialize, Deserialize)]
16pub struct EventData {
17    /// The input passed to the `Runnable` that generated the event.
18    ///
19    /// Inputs will sometimes be available at the *START* of the `Runnable`, and
20    /// sometimes at the *END* of the `Runnable`.
21    ///
22    /// If a `Runnable` is able to stream its inputs, then its input by definition
23    /// won't be known until the *END* of the `Runnable` when it has finished streaming
24    /// its inputs.
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub input: Option<Value>,
27
28    /// The error that occurred during the execution of the `Runnable`.
29    ///
30    /// This field is only available if the `Runnable` raised an exception.
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub error: Option<String>,
33
34    /// The output of the `Runnable` that generated the event.
35    ///
36    /// Outputs will only be available at the *END* of the `Runnable`.
37    ///
38    /// For most `Runnable` objects, this field can be inferred from the `chunk` field,
39    /// though there might be some exceptions for special a cased `Runnable` (e.g., like
40    /// chat models), which may return more information.
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub output: Option<Value>,
43
44    /// A streaming chunk from the output that generated the event.
45    ///
46    /// Chunks support addition in general, and adding them up should result
47    /// in the output of the `Runnable` that generated the event.
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub chunk: Option<Value>,
50}
51
52impl EventData {
53    /// Create a new empty EventData.
54    pub fn new() -> Self {
55        Self::default()
56    }
57
58    /// Create EventData with an input value.
59    pub fn with_input(mut self, input: Value) -> Self {
60        self.input = Some(input);
61        self
62    }
63
64    /// Create EventData with an error.
65    pub fn with_error(mut self, error: impl Into<String>) -> Self {
66        self.error = Some(error.into());
67        self
68    }
69
70    /// Create EventData with an output value.
71    pub fn with_output(mut self, output: Value) -> Self {
72        self.output = Some(output);
73        self
74    }
75
76    /// Create EventData with a chunk value.
77    pub fn with_chunk(mut self, chunk: Value) -> Self {
78        self.chunk = Some(chunk);
79        self
80    }
81}
82
83/// Base streaming event.
84///
85/// Schema of a streaming event which is produced from the `astream_events` method.
86///
87/// Event names are of the format: `on_[runnable_type]_(start|stream|end)`.
88///
89/// Runnable types are one of:
90/// - **llm** - used by non chat models
91/// - **chat_model** - used by chat models
92/// - **prompt** - e.g., `ChatPromptTemplate`
93/// - **tool** - from tools defined via `@tool` decorator or inheriting from `Tool`/`BaseTool`
94/// - **chain** - most `Runnable` objects are of this type
95///
96/// Further, the events are categorized as one of:
97/// - **start** - when the `Runnable` starts
98/// - **stream** - when the `Runnable` is streaming
99/// - **end** - when the `Runnable` ends
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct BaseStreamEvent {
102    /// Event names are of the format: `on_[runnable_type]_(start|stream|end)`.
103    pub event: String,
104
105    /// A randomly generated ID to keep track of the execution of the given `Runnable`.
106    ///
107    /// Each child `Runnable` that gets invoked as part of the execution of a parent
108    /// `Runnable` is assigned its own unique ID.
109    pub run_id: String,
110
111    /// Tags associated with the `Runnable` that generated this event.
112    ///
113    /// Tags are always inherited from parent `Runnable` objects.
114    ///
115    /// Tags can either be bound to a `Runnable` using `.with_config({"tags": ["hello"]})`
116    /// or passed at run time using `.astream_events(..., {"tags": ["hello"]})`.
117    #[serde(default, skip_serializing_if = "Vec::is_empty")]
118    pub tags: Vec<String>,
119
120    /// Metadata associated with the `Runnable` that generated this event.
121    ///
122    /// Metadata can either be bound to a `Runnable` using
123    /// `.with_config({"metadata": { "foo": "bar" }})`
124    /// or passed at run time using
125    /// `.astream_events(..., {"metadata": {"foo": "bar"}})`.
126    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
127    pub metadata: HashMap<String, Value>,
128
129    /// A list of the parent IDs associated with this event.
130    ///
131    /// Root Events will have an empty list.
132    ///
133    /// For example, if a `Runnable` A calls `Runnable` B, then the event generated by
134    /// `Runnable` B will have `Runnable` A's ID in the `parent_ids` field.
135    ///
136    /// The order of the parent IDs is from the root parent to the immediate parent.
137    ///
138    /// Only supported as of v2 of the astream events API. v1 will return an empty list.
139    #[serde(default)]
140    pub parent_ids: Vec<String>,
141}
142
143impl BaseStreamEvent {
144    /// Create a new BaseStreamEvent.
145    pub fn new(event: impl Into<String>, run_id: impl Into<String>) -> Self {
146        Self {
147            event: event.into(),
148            run_id: run_id.into(),
149            tags: Vec::new(),
150            metadata: HashMap::new(),
151            parent_ids: Vec::new(),
152        }
153    }
154
155    /// Set the tags for this event.
156    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
157        self.tags = tags;
158        self
159    }
160
161    /// Set the metadata for this event.
162    pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
163        self.metadata = metadata;
164        self
165    }
166
167    /// Set the parent IDs for this event.
168    pub fn with_parent_ids(mut self, parent_ids: Vec<String>) -> Self {
169        self.parent_ids = parent_ids;
170        self
171    }
172}
173
174/// A standard stream event that follows LangChain convention for event data.
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct StandardStreamEvent {
177    /// The base event fields.
178    #[serde(flatten)]
179    pub base: BaseStreamEvent,
180
181    /// Event data.
182    ///
183    /// The contents of the event data depend on the event type.
184    pub data: EventData,
185
186    /// The name of the `Runnable` that generated the event.
187    pub name: String,
188}
189
190impl StandardStreamEvent {
191    /// Create a new StandardStreamEvent.
192    pub fn new(
193        event: impl Into<String>,
194        run_id: impl Into<String>,
195        name: impl Into<String>,
196    ) -> Self {
197        Self {
198            base: BaseStreamEvent::new(event, run_id),
199            data: EventData::new(),
200            name: name.into(),
201        }
202    }
203
204    /// Set the data for this event.
205    pub fn with_data(mut self, data: EventData) -> Self {
206        self.data = data;
207        self
208    }
209
210    /// Set the tags for this event.
211    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
212        self.base.tags = tags;
213        self
214    }
215
216    /// Set the metadata for this event.
217    pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
218        self.base.metadata = metadata;
219        self
220    }
221
222    /// Set the parent IDs for this event.
223    pub fn with_parent_ids(mut self, parent_ids: Vec<String>) -> Self {
224        self.base.parent_ids = parent_ids;
225        self
226    }
227}
228
/// Event name carried by every custom (user-created) stream event.
pub const CUSTOM_EVENT_TYPE: &str = "on_custom_event";
231
232/// Custom stream event created by the user.
233#[derive(Debug, Clone, Serialize, Deserialize)]
234pub struct CustomStreamEvent {
235    /// The base event fields.
236    #[serde(flatten)]
237    pub base: BaseStreamEvent,
238
239    /// User defined name for the event.
240    pub name: String,
241
242    /// The data associated with the event. Free form and can be anything.
243    pub data: Value,
244}
245
246impl CustomStreamEvent {
247    /// Create a new CustomStreamEvent.
248    ///
249    /// The event type is automatically set to "on_custom_event".
250    pub fn new(run_id: impl Into<String>, name: impl Into<String>, data: Value) -> Self {
251        Self {
252            base: BaseStreamEvent::new(CUSTOM_EVENT_TYPE, run_id),
253            name: name.into(),
254            data,
255        }
256    }
257
258    /// Set the tags for this event.
259    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
260        self.base.tags = tags;
261        self
262    }
263
264    /// Set the metadata for this event.
265    pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
266        self.base.metadata = metadata;
267        self
268    }
269
270    /// Set the parent IDs for this event.
271    pub fn with_parent_ids(mut self, parent_ids: Vec<String>) -> Self {
272        self.base.parent_ids = parent_ids;
273        self
274    }
275}
276
277/// Union type for stream events.
278///
279/// A stream event can be either a standard event following LangChain conventions,
280/// or a custom event created by the user.
281#[derive(Debug, Clone, Serialize, Deserialize)]
282#[serde(untagged)]
283pub enum StreamEvent {
284    /// A standard stream event.
285    Standard(StandardStreamEvent),
286    /// A custom stream event.
287    Custom(CustomStreamEvent),
288}
289
290impl StreamEvent {
291    /// Get the event type string.
292    pub fn event(&self) -> &str {
293        match self {
294            StreamEvent::Standard(e) => &e.base.event,
295            StreamEvent::Custom(e) => &e.base.event,
296        }
297    }
298
299    /// Get the run ID.
300    pub fn run_id(&self) -> &str {
301        match self {
302            StreamEvent::Standard(e) => &e.base.run_id,
303            StreamEvent::Custom(e) => &e.base.run_id,
304        }
305    }
306
307    /// Get the name.
308    pub fn name(&self) -> &str {
309        match self {
310            StreamEvent::Standard(e) => &e.name,
311            StreamEvent::Custom(e) => &e.name,
312        }
313    }
314
315    /// Get the tags.
316    pub fn tags(&self) -> &[String] {
317        match self {
318            StreamEvent::Standard(e) => &e.base.tags,
319            StreamEvent::Custom(e) => &e.base.tags,
320        }
321    }
322
323    /// Get the metadata.
324    pub fn metadata(&self) -> &HashMap<String, Value> {
325        match self {
326            StreamEvent::Standard(e) => &e.base.metadata,
327            StreamEvent::Custom(e) => &e.base.metadata,
328        }
329    }
330
331    /// Get the parent IDs.
332    pub fn parent_ids(&self) -> &[String] {
333        match self {
334            StreamEvent::Standard(e) => &e.base.parent_ids,
335            StreamEvent::Custom(e) => &e.base.parent_ids,
336        }
337    }
338
339    /// Check if this is a custom event.
340    pub fn is_custom(&self) -> bool {
341        matches!(self, StreamEvent::Custom(_))
342    }
343
344    /// Check if this is a standard event.
345    pub fn is_standard(&self) -> bool {
346        matches!(self, StreamEvent::Standard(_))
347    }
348}
349
350impl From<StandardStreamEvent> for StreamEvent {
351    fn from(event: StandardStreamEvent) -> Self {
352        StreamEvent::Standard(event)
353    }
354}
355
356impl From<CustomStreamEvent> for StreamEvent {
357    fn from(event: CustomStreamEvent) -> Self {
358        StreamEvent::Custom(event)
359    }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// The builder setters fill in only the fields they name.
    #[test]
    fn test_event_data() {
        let built = EventData::new()
            .with_input(json!("hello"))
            .with_output(json!("world"));

        assert_eq!(built.input, Some(json!("hello")));
        assert_eq!(built.output, Some(json!("world")));
        assert!(built.error.is_none());
        assert!(built.chunk.is_none());
    }

    /// A standard event keeps the values given to `new` and to its setters.
    #[test]
    fn test_standard_stream_event() {
        let payload = EventData::new().with_input(json!({"key": "value"}));
        let standard = StandardStreamEvent::new("on_chain_start", "run-123", "my_chain")
            .with_tags(vec!["tag1".to_string()])
            .with_data(payload);

        assert_eq!(standard.base.event, "on_chain_start");
        assert_eq!(standard.base.run_id, "run-123");
        assert_eq!(standard.name, "my_chain");
        assert_eq!(standard.base.tags, vec!["tag1"]);
        assert!(standard.data.input.is_some());
    }

    /// A custom event always carries the custom event name.
    #[test]
    fn test_custom_stream_event() {
        let payload = json!({
            "custom_field": "custom_value"
        });
        let custom = CustomStreamEvent::new("run-456", "my_custom_event", payload);

        assert_eq!(custom.base.event, CUSTOM_EVENT_TYPE);
        assert_eq!(custom.base.run_id, "run-456");
        assert_eq!(custom.name, "my_custom_event");
        assert_eq!(custom.data, json!({"custom_field": "custom_value"}));
    }

    /// The enum accessors and variant predicates report the wrapped event.
    #[test]
    fn test_stream_event_enum() {
        let wrapped_standard =
            StreamEvent::Standard(StandardStreamEvent::new("on_chain_end", "run-1", "chain"));
        let wrapped_custom =
            StreamEvent::Custom(CustomStreamEvent::new("run-2", "custom", json!(null)));

        assert!(wrapped_standard.is_standard());
        assert!(!wrapped_standard.is_custom());
        assert_eq!(wrapped_standard.event(), "on_chain_end");
        assert_eq!(wrapped_standard.name(), "chain");

        assert!(wrapped_custom.is_custom());
        assert!(!wrapped_custom.is_standard());
        assert_eq!(wrapped_custom.event(), CUSTOM_EVENT_TYPE);
        assert_eq!(wrapped_custom.name(), "custom");
    }

    /// A standard event survives a JSON round trip.
    #[test]
    fn test_stream_event_serialization() {
        let original = StandardStreamEvent::new("on_chain_start", "run-123", "test_chain")
            .with_data(EventData::new().with_input(json!("input")));

        let encoded = serde_json::to_string(&original).unwrap();
        for needle in ["on_chain_start", "run-123", "test_chain"] {
            assert!(encoded.contains(needle));
        }

        let decoded: StandardStreamEvent = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.base.event, "on_chain_start");
        assert_eq!(decoded.base.run_id, "run-123");
        assert_eq!(decoded.name, "test_chain");
    }
}