Skip to main content

ralph_workflow/json_parser/
stream_classifier.rs

//! Stream event classifier for algorithmic detection of partial vs complete events.
//!
//! This module provides a classifier that can distinguish between different types
//! of streaming events without prior knowledge of the specific protocol. It uses
//! heuristics based on JSON structure and field names to make conservative decisions
//! about event classification.
7
8use serde_json::Value;
9
/// Classification of a streaming event.
///
/// Tells the consumer how an event should be processed and displayed:
/// accumulated with its neighbours, shown as-is, or treated as
/// session/metadata signalling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StreamEventType {
    /// Partial/delta content that should be accumulated.
    ///
    /// Such an event carries an incremental update that has to be combined
    /// with other events to form complete content.
    Partial,

    /// Complete, self-contained content.
    ///
    /// Such an event carries full content that can be displayed on its own.
    Complete,

    /// Control/metadata event.
    ///
    /// Such an event provides session information (start/stop) or metadata
    /// and does not contain user-facing content.
    Control,
}
33
34/// Result of event classification
35///
36/// Contains the classification along with extracted metadata about the event.
37#[derive(Debug, Clone)]
38pub struct ClassificationResult {
39    /// The classified event type
40    pub event_type: StreamEventType,
41    /// Detected event type name (e.g., "message", "delta", "error")
42    pub type_name: Option<String>,
43    /// The primary content field if found
44    pub content_field: Option<String>,
45}
46
/// Stream event classifier.
///
/// Analyzes JSON events to determine whether they represent partial content,
/// complete messages, or control events. Uses conservative heuristics that
/// prefer showing content over hiding it.
#[derive(Debug, Clone)]
pub struct StreamEventClassifier {
    /// Threshold for considering text content "substantial" enough to be complete.
    substantial_content_threshold: usize,
}
56
57impl Default for StreamEventClassifier {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63impl StreamEventClassifier {
64    /// Create a new classifier with default settings
65    pub const fn new() -> Self {
66        Self {
67            substantial_content_threshold: 50,
68        }
69    }
70
71    /// Classify a JSON event
72    ///
73    /// # Arguments
74    /// * `value` - The parsed JSON value to classify
75    ///
76    /// # Returns
77    /// A `ClassificationResult` with the detected event type and metadata
78    pub fn classify(&self, value: &Value) -> ClassificationResult {
79        let Some(obj) = value.as_object() else {
80            return ClassificationResult {
81                event_type: StreamEventType::Complete,
82                type_name: None,
83                content_field: None,
84            };
85        };
86
87        let type_name = obj
88            .get("type")
89            .or_else(|| obj.get("event_type"))
90            .and_then(|v| v.as_str())
91            .map(std::string::ToString::to_string);
92
93        let is_delta = obj
94            .get("delta")
95            .and_then(serde_json::Value::as_bool)
96            .unwrap_or(false);
97
98        if Self::is_control_event(type_name.as_ref(), obj) {
99            return ClassificationResult {
100                event_type: StreamEventType::Control,
101                type_name,
102                content_field: None,
103            };
104        }
105
106        if self.is_partial_event(type_name.as_ref(), obj, is_delta) {
107            return ClassificationResult {
108                event_type: StreamEventType::Partial,
109                type_name,
110                content_field: Self::find_content_field(obj),
111            };
112        }
113
114        ClassificationResult {
115            event_type: StreamEventType::Complete,
116            type_name,
117            content_field: Self::find_content_field(obj),
118        }
119    }
120
121    fn is_control_event(type_name: Option<&String>, obj: &serde_json::Map<String, Value>) -> bool {
122        if let Some(name) = type_name {
123            let control_patterns = [
124                "start",
125                "started",
126                "init",
127                "initialize",
128                "stop",
129                "stopped",
130                "end",
131                "done",
132                "complete",
133                "error",
134                "fail",
135                "failed",
136                "failure",
137                "ping",
138                "pong",
139                "heartbeat",
140                "keepalive",
141                "metadata",
142                "meta",
143            ];
144
145            let name_lower = name.to_lowercase();
146            if control_patterns
147                .iter()
148                .any(|pattern| name_lower.contains(pattern))
149            {
150                return true;
151            }
152        }
153
154        let has_status = obj.contains_key("status") || obj.contains_key("error");
155        let has_content = Self::has_content_field(obj);
156        has_status && !has_content
157    }
158
159    fn is_partial_event(
160        &self,
161        type_name: Option<&String>,
162        obj: &serde_json::Map<String, Value>,
163        explicit_delta: bool,
164    ) -> bool {
165        if explicit_delta {
166            return true;
167        }
168
169        if let Some(name) = type_name {
170            let partial_patterns = [
171                "delta",
172                "partial",
173                "increment",
174                "chunk",
175                "progress",
176                "streaming",
177                "update",
178            ];
179
180            let name_lower = name.to_lowercase();
181            if partial_patterns
182                .iter()
183                .any(|pattern| name_lower.contains(pattern))
184            {
185                return true;
186            }
187        }
188
189        let delta_fields = ["delta", "partial", "increment"];
190        if delta_fields.iter().any(|field| {
191            obj.get(*field).is_some_and(|value| {
192                value.is_string()
193                    || value.is_array()
194                    || value.is_object()
195                    || (value.is_number() && value.as_i64() != Some(0))
196            })
197        }) {
198            return true;
199        }
200
201        if !explicit_delta
202            && (type_name.is_none()
203                || !type_name.as_ref().is_some_and(|n| {
204                    let n_lower = n.to_lowercase();
205                    n_lower.contains("delta")
206                        || n_lower.contains("partial")
207                        || n_lower.contains("chunk")
208                }))
209        {
210            if let Some(content) = Self::find_content_field(obj) {
211                if let Some(text) = obj.get(&content).and_then(|v| v.as_str()) {
212                    if text.len() < self.substantial_content_threshold {
213                        let text_lower = text.to_lowercase();
214                        let trimmed = text.trim();
215
216                        let complete_responses = [
217                            "ok",
218                            "okay",
219                            "yes",
220                            "no",
221                            "true",
222                            "false",
223                            "done",
224                            "finished",
225                            "complete",
226                            "success",
227                            "failed",
228                            "error",
229                            "warning",
230                            "info",
231                            "debug",
232                            "pending",
233                            "processing",
234                            "running",
235                            "none",
236                            "null",
237                            "empty",
238                        ];
239                        let is_complete_response = complete_responses.contains(&trimmed);
240
241                        let ends_with_terminal = trimmed.ends_with('.')
242                            || trimmed.ends_with('!')
243                            || trimmed.ends_with('?');
244
245                        let has_newline = text.contains('\n');
246
247                        let is_error_message = text_lower.contains("error:")
248                            || text_lower.contains("warning:")
249                            || text_lower.starts_with("error")
250                            || text_lower.starts_with("warning");
251
252                        if is_complete_response
253                            || ends_with_terminal
254                            || has_newline
255                            || is_error_message
256                        {
257                            return false;
258                        }
259
260                        return true;
261                    }
262                }
263            }
264        }
265
266        false
267    }
268
269    fn find_content_field(obj: &serde_json::Map<String, Value>) -> Option<String> {
270        let content_fields = [
271            "content",
272            "text",
273            "message",
274            "data",
275            "output",
276            "result",
277            "response",
278            "body",
279            "thinking",
280            "reasoning",
281            "delta",
282        ];
283
284        content_fields
285            .iter()
286            .find(|field| {
287                obj.get(**field)
288                    .is_some_and(|v| matches!(v, Value::String(_)))
289            })
290            .map(|f| f.to_string())
291    }
292
293    fn has_content_field(obj: &serde_json::Map<String, Value>) -> bool {
294        Self::find_content_field(obj).is_some()
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Runs a default classifier over `event` and returns the full result.
    fn classify(event: &serde_json::Value) -> ClassificationResult {
        StreamEventClassifier::new().classify(event)
    }

    #[test]
    fn test_classify_delta_event() {
        let event = json!({
            "type": "content_block_delta",
            "index": 0,
            "delta": {"type": "text_delta", "text": "Hello"}
        });

        assert_eq!(classify(&event).event_type, StreamEventType::Partial);
    }

    #[test]
    fn test_classify_control_event() {
        let event = json!({
            "type": "message_start",
            "message": {"id": "msg_123"}
        });

        assert_eq!(classify(&event).event_type, StreamEventType::Control);
    }

    #[test]
    fn test_classify_complete_message() {
        let event = json!({
            "type": "message",
            "content": "This is a complete message with substantial content that should be displayed as is."
        });

        assert_eq!(classify(&event).event_type, StreamEventType::Complete);
    }

    #[test]
    fn test_classify_explicit_delta_flag() {
        let event = json!({
            "type": "message",
            "delta": true,
            "content": "partial"
        });

        assert_eq!(classify(&event).event_type, StreamEventType::Partial);
    }

    #[test]
    fn test_classify_error_event() {
        let event = json!({
            "type": "error",
            "message": "Something went wrong"
        });

        assert_eq!(classify(&event).event_type, StreamEventType::Control);
    }

    #[test]
    fn test_small_content_is_partial() {
        let event = json!({
            "type": "chunk",
            "text": "Hi"
        });

        assert_eq!(classify(&event).event_type, StreamEventType::Partial);
    }

    #[test]
    fn test_substantial_content_is_complete() {
        let long_text = "This is a substantial message that exceeds the threshold and should be considered complete.".repeat(2);
        let event = json!({
            "type": "message",
            "content": long_text
        });

        assert_eq!(classify(&event).event_type, StreamEventType::Complete);
    }

    #[test]
    fn test_non_object_value_is_complete() {
        let result = classify(&json!([1, 2, 3]));

        assert_eq!(result.event_type, StreamEventType::Complete);
        assert!(result.type_name.is_none());
        assert!(result.content_field.is_none());
    }

    #[test]
    fn test_status_without_content_is_control() {
        let result = classify(&json!({ "status": "connected" }));

        assert_eq!(result.event_type, StreamEventType::Control);
        assert!(result.content_field.is_none());
    }

    #[test]
    fn test_content_field_and_type_name_are_reported() {
        let event = json!({
            "type": "message",
            "text": "A full sentence that is comfortably longer than the threshold."
        });

        let result = classify(&event);
        assert_eq!(result.event_type, StreamEventType::Complete);
        assert_eq!(result.type_name.as_deref(), Some("message"));
        assert_eq!(result.content_field.as_deref(), Some("text"));
    }
}