// ralph_workflow/json_parser/stream_classifier.rs
use serde_json::Value;
9
/// Category assigned to a streamed JSON event by `StreamEventClassifier::classify`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StreamEventType {
    /// An incremental fragment of a larger message: the event is flagged with
    /// `"delta": true`, has a delta/chunk/partial-style type name, carries a
    /// delta payload field, or holds only a short, unfinished-looking string.
    Partial,

    /// A self-contained event. This is also the fallback for JSON values that
    /// are not objects and for objects not recognised as partial or control.
    Complete,

    /// A lifecycle or status event (start, stop, error, ping, metadata, ...),
    /// or an object with a `status`/`error` key and no textual content field.
    Control,
}
33
34#[derive(Debug, Clone)]
38pub struct ClassificationResult {
39 pub event_type: StreamEventType,
41 pub type_name: Option<String>,
43 pub content_field: Option<String>,
45}
46
/// Heuristic classifier that sorts streamed JSON events into partial,
/// complete, and control categories.
///
/// Derives `Debug` and `Clone` so it can be logged and duplicated like the
/// other public types in this module.
#[derive(Debug, Clone)]
pub struct StreamEventClassifier {
    /// Byte length below which a content string is examined by the
    /// short-fragment heuristic; strings at or above this length are never
    /// treated as partial on the basis of their text alone.
    substantial_content_threshold: usize,
}
56
57impl Default for StreamEventClassifier {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl StreamEventClassifier {
64 pub const fn new() -> Self {
66 Self {
67 substantial_content_threshold: 50,
68 }
69 }
70
71 pub fn classify(&self, value: &Value) -> ClassificationResult {
79 let Some(obj) = value.as_object() else {
80 return ClassificationResult {
81 event_type: StreamEventType::Complete,
82 type_name: None,
83 content_field: None,
84 };
85 };
86
87 let type_name = obj
88 .get("type")
89 .or_else(|| obj.get("event_type"))
90 .and_then(|v| v.as_str())
91 .map(std::string::ToString::to_string);
92
93 let is_delta = obj
94 .get("delta")
95 .and_then(serde_json::Value::as_bool)
96 .unwrap_or(false);
97
98 if Self::is_control_event(type_name.as_ref(), obj) {
99 return ClassificationResult {
100 event_type: StreamEventType::Control,
101 type_name,
102 content_field: None,
103 };
104 }
105
106 if self.is_partial_event(type_name.as_ref(), obj, is_delta) {
107 return ClassificationResult {
108 event_type: StreamEventType::Partial,
109 type_name,
110 content_field: Self::find_content_field(obj),
111 };
112 }
113
114 ClassificationResult {
115 event_type: StreamEventType::Complete,
116 type_name,
117 content_field: Self::find_content_field(obj),
118 }
119 }
120
121 fn is_control_event(type_name: Option<&String>, obj: &serde_json::Map<String, Value>) -> bool {
122 if let Some(name) = type_name {
123 let control_patterns = [
124 "start",
125 "started",
126 "init",
127 "initialize",
128 "stop",
129 "stopped",
130 "end",
131 "done",
132 "complete",
133 "error",
134 "fail",
135 "failed",
136 "failure",
137 "ping",
138 "pong",
139 "heartbeat",
140 "keepalive",
141 "metadata",
142 "meta",
143 ];
144
145 let name_lower = name.to_lowercase();
146 if control_patterns
147 .iter()
148 .any(|pattern| name_lower.contains(pattern))
149 {
150 return true;
151 }
152 }
153
154 let has_status = obj.contains_key("status") || obj.contains_key("error");
155 let has_content = Self::has_content_field(obj);
156 has_status && !has_content
157 }
158
159 fn is_partial_event(
160 &self,
161 type_name: Option<&String>,
162 obj: &serde_json::Map<String, Value>,
163 explicit_delta: bool,
164 ) -> bool {
165 if explicit_delta {
166 return true;
167 }
168
169 if let Some(name) = type_name {
170 let partial_patterns = [
171 "delta",
172 "partial",
173 "increment",
174 "chunk",
175 "progress",
176 "streaming",
177 "update",
178 ];
179
180 let name_lower = name.to_lowercase();
181 if partial_patterns
182 .iter()
183 .any(|pattern| name_lower.contains(pattern))
184 {
185 return true;
186 }
187 }
188
189 let delta_fields = ["delta", "partial", "increment"];
190 if delta_fields.iter().any(|field| {
191 obj.get(*field).is_some_and(|value| {
192 value.is_string()
193 || value.is_array()
194 || value.is_object()
195 || (value.is_number() && value.as_i64() != Some(0))
196 })
197 }) {
198 return true;
199 }
200
201 if !explicit_delta
202 && (type_name.is_none()
203 || !type_name.as_ref().is_some_and(|n| {
204 let n_lower = n.to_lowercase();
205 n_lower.contains("delta")
206 || n_lower.contains("partial")
207 || n_lower.contains("chunk")
208 }))
209 {
210 if let Some(content) = Self::find_content_field(obj) {
211 if let Some(text) = obj.get(&content).and_then(|v| v.as_str()) {
212 if text.len() < self.substantial_content_threshold {
213 let text_lower = text.to_lowercase();
214 let trimmed = text.trim();
215
216 let complete_responses = [
217 "ok",
218 "okay",
219 "yes",
220 "no",
221 "true",
222 "false",
223 "done",
224 "finished",
225 "complete",
226 "success",
227 "failed",
228 "error",
229 "warning",
230 "info",
231 "debug",
232 "pending",
233 "processing",
234 "running",
235 "none",
236 "null",
237 "empty",
238 ];
239 let is_complete_response = complete_responses.contains(&trimmed);
240
241 let ends_with_terminal = trimmed.ends_with('.')
242 || trimmed.ends_with('!')
243 || trimmed.ends_with('?');
244
245 let has_newline = text.contains('\n');
246
247 let is_error_message = text_lower.contains("error:")
248 || text_lower.contains("warning:")
249 || text_lower.starts_with("error")
250 || text_lower.starts_with("warning");
251
252 if is_complete_response
253 || ends_with_terminal
254 || has_newline
255 || is_error_message
256 {
257 return false;
258 }
259
260 return true;
261 }
262 }
263 }
264 }
265
266 false
267 }
268
269 fn find_content_field(obj: &serde_json::Map<String, Value>) -> Option<String> {
270 let content_fields = [
271 "content",
272 "text",
273 "message",
274 "data",
275 "output",
276 "result",
277 "response",
278 "body",
279 "thinking",
280 "reasoning",
281 "delta",
282 ];
283
284 content_fields
285 .iter()
286 .find(|field| {
287 obj.get(**field)
288 .is_some_and(|v| matches!(v, Value::String(_)))
289 })
290 .map(|f| f.to_string())
291 }
292
293 fn has_content_field(obj: &serde_json::Map<String, Value>) -> bool {
294 Self::find_content_field(obj).is_some()
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Runs a freshly constructed classifier over `event` and returns only the
    /// resulting category.
    fn event_type_of(event: &serde_json::Value) -> StreamEventType {
        StreamEventClassifier::new().classify(event).event_type
    }

    #[test]
    fn test_classify_delta_event() {
        // Delta-style type name with an object payload under "delta".
        let event = json!({
            "type": "content_block_delta",
            "index": 0,
            "delta": {"type": "text_delta", "text": "Hello"}
        });

        assert_eq!(event_type_of(&event), StreamEventType::Partial);
    }

    #[test]
    fn test_classify_control_event() {
        // "message_start" contains the "start" control keyword.
        let event = json!({
            "type": "message_start",
            "message": {"id": "msg_123"}
        });

        assert_eq!(event_type_of(&event), StreamEventType::Control);
    }

    #[test]
    fn test_classify_complete_message() {
        let event = json!({
            "type": "message",
            "content": "This is a complete message with substantial content that should be displayed as is."
        });

        assert_eq!(event_type_of(&event), StreamEventType::Complete);
    }

    #[test]
    fn test_classify_explicit_delta_flag() {
        // The boolean "delta" flag alone is enough to mark the event partial.
        let event = json!({
            "type": "message",
            "delta": true,
            "content": "partial"
        });

        assert_eq!(event_type_of(&event), StreamEventType::Partial);
    }

    #[test]
    fn test_classify_error_event() {
        let event = json!({
            "type": "error",
            "message": "Something went wrong"
        });

        assert_eq!(event_type_of(&event), StreamEventType::Control);
    }

    #[test]
    fn test_small_content_is_partial() {
        let event = json!({
            "type": "chunk",
            "text": "Hi"
        });

        assert_eq!(event_type_of(&event), StreamEventType::Partial);
    }

    #[test]
    fn test_substantial_content_is_complete() {
        let long_text = "This is a substantial message that exceeds the threshold and should be considered complete.".repeat(2);
        let event = json!({
            "type": "message",
            "content": long_text
        });

        assert_eq!(event_type_of(&event), StreamEventType::Complete);
    }
}