Skip to main content

prompty_openai/
processor.rs

1//! OpenAI processor — extracts results from OpenAI API responses.
2//!
//! Handles chat completions, Responses API payloads, embeddings, image generation, and streaming responses.
4
5use async_trait::async_trait;
6use serde_json::Value;
7
8use prompty::interfaces::{InvokerError, Processor};
9use prompty::model::Prompty;
10use prompty::types::ToolCall;
11
12/// OpenAI processor implementing the `Processor` trait.
13pub struct OpenAIProcessor;
14
15#[async_trait]
16impl Processor for OpenAIProcessor {
17    async fn process(&self, agent: &Prompty, response: Value) -> Result<Value, InvokerError> {
18        process_response(agent, &response)
19    }
20
21    fn process_stream(
22        &self,
23        inner: std::pin::Pin<Box<dyn futures::Stream<Item = Value> + Send>>,
24    ) -> Result<
25        std::pin::Pin<Box<dyn futures::Stream<Item = prompty::types::StreamChunk> + Send>>,
26        InvokerError,
27    > {
28        Ok(process_stream(inner))
29    }
30}
31
32/// Process an OpenAI API response, dispatching by response shape.
33pub fn process_response(agent: &Prompty, response: &Value) -> Result<Value, InvokerError> {
34    // Responses API — has "object" == "response"
35    if response.get("object").and_then(Value::as_str) == Some("response") {
36        return process_responses_api(agent, response);
37    }
38
39    // ChatCompletion — has "choices"
40    if let Some(choices) = response.get("choices").and_then(Value::as_array) {
41        return process_chat_completion(agent, choices);
42    }
43
44    // Embedding — has "data" and "object" == "list"
45    if response.get("object").and_then(Value::as_str) == Some("list") {
46        if let Some(data) = response.get("data").and_then(Value::as_array) {
47            return process_embedding(data);
48        }
49    }
50
51    // Image — has "data" array with url/b64_json
52    if let Some(data) = response.get("data").and_then(Value::as_array) {
53        if data.iter().any(|d| {
54            d.get("url").is_some_and(|v| !v.is_null())
55                || d.get("b64_json").is_some_and(|v| !v.is_null())
56        }) {
57            return process_image(data);
58        }
59    }
60
61    // Unknown response shape — return as-is
62    Ok(response.clone())
63}
64
65// ---------------------------------------------------------------------------
66// Chat completion
67// ---------------------------------------------------------------------------
68
69fn process_chat_completion(agent: &Prompty, choices: &[Value]) -> Result<Value, InvokerError> {
70    let first = choices
71        .first()
72        .ok_or_else(|| InvokerError::Process("Empty choices array".to_string().into()))?;
73
74    let message = first
75        .get("message")
76        .ok_or_else(|| InvokerError::Process("Missing message in choice".to_string().into()))?;
77
78    // Tool calls take priority
79    if let Some(tool_calls) = message.get("tool_calls").and_then(Value::as_array) {
80        if !tool_calls.is_empty() {
81            let calls: Vec<Value> = tool_calls
82                .iter()
83                .map(|tc| {
84                    let func = tc.get("function").unwrap_or(tc);
85                    serde_json::json!({
86                        "id": tc.get("id").and_then(Value::as_str).unwrap_or(""),
87                        "name": func.get("name").and_then(Value::as_str).unwrap_or(""),
88                        "arguments": func.get("arguments").and_then(Value::as_str).unwrap_or("{}"),
89                    })
90                })
91                .collect();
92            return Ok(Value::Array(calls));
93        }
94    }
95
96    // Content
97    let content = message.get("content");
98
99    // Refusal
100    if content.is_none() || content == Some(&Value::Null) {
101        if let Some(refusal) = message.get("refusal").and_then(Value::as_str) {
102            return Ok(Value::String(refusal.to_string()));
103        }
104    }
105
106    let content_str = content.and_then(Value::as_str).unwrap_or("");
107
108    // Structured output: if agent has outputs, try to parse as JSON.
109    // Falls back to raw string gracefully if parsing fails.
110    if let Some(outputs) = agent.as_outputs() {
111        if !outputs.is_empty() {
112            if let Ok(parsed) = serde_json::from_str::<Value>(content_str) {
113                return Ok(parsed);
114            }
115            // Fall through to return raw string
116        }
117    }
118
119    Ok(Value::String(content_str.to_string()))
120}
121
122// ---------------------------------------------------------------------------
123// Responses API (OpenAI new format)
124// ---------------------------------------------------------------------------
125
126fn process_responses_api(agent: &Prompty, response: &Value) -> Result<Value, InvokerError> {
127    // Check for tool calls in output items
128    if let Some(output) = response.get("output").and_then(Value::as_array) {
129        let tool_calls: Vec<Value> = output
130            .iter()
131            .filter(|item| item.get("type").and_then(Value::as_str) == Some("function_call"))
132            .map(|item| {
133                serde_json::json!({
134                    "id": item.get("call_id").and_then(Value::as_str).unwrap_or(""),
135                    "name": item.get("name").and_then(Value::as_str).unwrap_or(""),
136                    "arguments": item.get("arguments").and_then(Value::as_str).unwrap_or("{}"),
137                })
138            })
139            .collect();
140
141        if !tool_calls.is_empty() {
142            return Ok(Value::Array(tool_calls));
143        }
144    }
145
146    // Extract output_text (convenience field)
147    let output_text = response
148        .get("output_text")
149        .and_then(Value::as_str)
150        .unwrap_or("");
151
152    // Structured output
153    if let Some(outputs) = agent.as_outputs() {
154        if !outputs.is_empty() {
155            if let Ok(parsed) = serde_json::from_str::<Value>(output_text) {
156                return Ok(parsed);
157            }
158        }
159    }
160
161    Ok(Value::String(output_text.to_string()))
162}
163
164// ---------------------------------------------------------------------------
165// Embedding
166// ---------------------------------------------------------------------------
167
168fn process_embedding(data: &[Value]) -> Result<Value, InvokerError> {
169    let vectors: Vec<Value> = data
170        .iter()
171        .filter_map(|d| d.get("embedding").cloned())
172        .collect();
173
174    if vectors.len() == 1 {
175        Ok(vectors.into_iter().next().unwrap())
176    } else {
177        Ok(Value::Array(vectors))
178    }
179}
180
181// ---------------------------------------------------------------------------
182// Image
183// ---------------------------------------------------------------------------
184
185fn process_image(data: &[Value]) -> Result<Value, InvokerError> {
186    let urls: Vec<Value> = data
187        .iter()
188        .map(|d| {
189            // Prefer url, fall back to b64_json, skip nulls
190            let url = d.get("url").filter(|v| !v.is_null());
191            let b64 = d.get("b64_json").filter(|v| !v.is_null());
192            url.or(b64).cloned().unwrap_or(Value::Null)
193        })
194        .collect();
195
196    if urls.len() == 1 {
197        Ok(urls.into_iter().next().unwrap())
198    } else {
199        Ok(Value::Array(urls))
200    }
201}
202
203// ---------------------------------------------------------------------------
204// Extract tool calls helper (used by pipeline)
205// ---------------------------------------------------------------------------
206
207/// Try to extract tool calls from a processed response value.
208pub fn extract_tool_calls(response: &Value) -> Option<Vec<ToolCall>> {
209    let arr = response.as_array()?;
210    let calls: Vec<ToolCall> = arr
211        .iter()
212        .filter_map(|v| {
213            let id = v.get("id")?.as_str()?.to_string();
214            let name = v.get("name")?.as_str()?.to_string();
215            let arguments = v.get("arguments")?.as_str()?.to_string();
216            Some(ToolCall {
217                id,
218                name,
219                arguments,
220            })
221        })
222        .collect();
223    if calls.is_empty() { None } else { Some(calls) }
224}
225
226// ---------------------------------------------------------------------------
227// Streaming processor — yields StreamChunk from raw SSE chunks
228// ---------------------------------------------------------------------------
229
230use prompty::types::StreamChunk;
231
232/// Process an OpenAI streaming response (SSE chunks) into a stream of `StreamChunk`s.
233///
234/// Handles three types of streaming deltas:
235/// - `delta.content` — yields `StreamChunk::Text`
236/// - `delta.tool_calls` — accumulates partial tool call chunks,
237///   yields `StreamChunk::Tool` objects when the stream ends
238/// - `delta.refusal` — yields an error as text
239///
240/// Matches TypeScript's `streamGenerator()` in `openai/processor.ts`.
241pub fn process_stream(
242    inner: impl futures::Stream<Item = Value> + Send + Unpin + 'static,
243) -> std::pin::Pin<Box<dyn futures::Stream<Item = StreamChunk> + Send>> {
244    Box::pin(OpenAIStreamProcessor::new(inner))
245}
246
247/// Stream adapter that processes OpenAI SSE chunks into StreamChunks.
248struct OpenAIStreamProcessor {
249    inner: std::pin::Pin<Box<dyn futures::Stream<Item = Value> + Send>>,
250    /// Accumulated partial tool calls, keyed by index.
251    tool_call_acc: std::collections::BTreeMap<usize, (String, String, String)>,
252    /// Phase: Streaming (pulling from inner) or Yielding (emitting accumulated tool calls).
253    phase: StreamPhase,
254    /// Buffer for chunks to yield (content text from a single SSE event can only produce one).
255    pending: std::collections::VecDeque<StreamChunk>,
256}
257
/// Lifecycle state of an `OpenAIStreamProcessor`.
enum StreamPhase {
    /// Still polling the inner SSE stream.
    Streaming,
    /// Inner stream exhausted; emitting the collected tool calls. The `usize` is the
    /// position in the `Vec` of the next tool call to emit.
    YieldingTools(Vec<ToolCall>, usize),
    /// Finished: every further poll returns `None`.
    Done,
}
264
265impl OpenAIStreamProcessor {
266    fn new(inner: impl futures::Stream<Item = Value> + Send + Unpin + 'static) -> Self {
267        Self {
268            inner: Box::pin(inner),
269            tool_call_acc: std::collections::BTreeMap::new(),
270            phase: StreamPhase::Streaming,
271            pending: std::collections::VecDeque::new(),
272        }
273    }
274}
275
276impl futures::Stream for OpenAIStreamProcessor {
277    type Item = StreamChunk;
278
279    fn poll_next(
280        self: std::pin::Pin<&mut Self>,
281        cx: &mut std::task::Context<'_>,
282    ) -> std::task::Poll<Option<Self::Item>> {
283        let this = self.get_mut();
284
285        // Return any pending chunks first
286        if let Some(chunk) = this.pending.pop_front() {
287            return std::task::Poll::Ready(Some(chunk));
288        }
289
290        match &mut this.phase {
291            StreamPhase::Streaming => {
292                match this.inner.as_mut().poll_next(cx) {
293                    std::task::Poll::Ready(Some(chunk)) => {
294                        let delta = chunk
295                            .get("choices")
296                            .and_then(Value::as_array)
297                            .and_then(|c| c.first())
298                            .and_then(|c| c.get("delta"));
299
300                        if let Some(delta) = delta {
301                            // Content text
302                            if let Some(content) = delta.get("content").and_then(Value::as_str) {
303                                if !content.is_empty() {
304                                    return std::task::Poll::Ready(Some(StreamChunk::Text(
305                                        content.to_string(),
306                                    )));
307                                }
308                            }
309
310                            // Tool call deltas
311                            if let Some(tc_deltas) =
312                                delta.get("tool_calls").and_then(Value::as_array)
313                            {
314                                for tc_delta in tc_deltas {
315                                    let idx =
316                                        tc_delta.get("index").and_then(Value::as_u64).unwrap_or(0)
317                                            as usize;
318                                    let entry =
319                                        this.tool_call_acc.entry(idx).or_insert_with(|| {
320                                            (String::new(), String::new(), String::new())
321                                        });
322                                    if let Some(id) = tc_delta.get("id").and_then(Value::as_str) {
323                                        entry.0 = id.to_string();
324                                    }
325                                    if let Some(name) =
326                                        tc_delta.pointer("/function/name").and_then(Value::as_str)
327                                    {
328                                        entry.1 = name.to_string();
329                                    }
330                                    if let Some(args) = tc_delta
331                                        .pointer("/function/arguments")
332                                        .and_then(Value::as_str)
333                                    {
334                                        entry.2.push_str(args);
335                                    }
336                                }
337                            }
338
339                            // Refusal — per spec §10.3: MUST raise error, stream MUST NOT continue
340                            if let Some(refusal) = delta.get("refusal").and_then(Value::as_str) {
341                                if !refusal.is_empty() {
342                                    this.phase = StreamPhase::Done;
343                                    return std::task::Poll::Ready(Some(StreamChunk::Error(
344                                        format!("Model refused: {refusal}"),
345                                    )));
346                                }
347                            }
348                        }
349
350                        // No content from this SSE event, wake and re-poll
351                        cx.waker().wake_by_ref();
352                        std::task::Poll::Pending
353                    }
354                    std::task::Poll::Ready(None) => {
355                        // Inner stream exhausted — yield accumulated tool calls
356                        let tools: Vec<ToolCall> = this
357                            .tool_call_acc
358                            .values()
359                            .map(|(id, name, args)| ToolCall {
360                                id: id.clone(),
361                                name: name.clone(),
362                                arguments: args.clone(),
363                            })
364                            .collect();
365
366                        if tools.is_empty() {
367                            this.phase = StreamPhase::Done;
368                            std::task::Poll::Ready(None)
369                        } else {
370                            let first = tools[0].clone();
371                            this.phase = StreamPhase::YieldingTools(tools, 1);
372                            std::task::Poll::Ready(Some(StreamChunk::Tool(first)))
373                        }
374                    }
375                    std::task::Poll::Pending => std::task::Poll::Pending,
376                }
377            }
378            StreamPhase::YieldingTools(tools, idx) if *idx < tools.len() => {
379                let tc = tools[*idx].clone();
380                *idx += 1;
381                std::task::Poll::Ready(Some(StreamChunk::Tool(tc)))
382            }
383            StreamPhase::YieldingTools(..) => {
384                this.phase = StreamPhase::Done;
385                std::task::Poll::Ready(None)
386            }
387            StreamPhase::Done => std::task::Poll::Ready(None),
388        }
389    }
390}
391
392// ---------------------------------------------------------------------------
393// Tests
394// ---------------------------------------------------------------------------
395
// Tests for `process_response` (chat completions, embeddings, images, pass-through),
// `extract_tool_calls`, and the `process_stream` SSE adapter.
#[cfg(test)]
mod tests {
    use super::*;
    use prompty::model::context::LoadContext;
    use serde_json::json;

    /// Build a minimal prompt agent for the tests. `outputs_json` is stored under
    /// `outputs` unless it is `Value::Null`, in which case no outputs are declared.
    fn make_agent(outputs_json: Value) -> Prompty {
        let mut data = json!({
            "name": "test",
            "kind": "prompt",
            "model": {"id": "gpt-4"},
            "instructions": "test",
        });
        if !outputs_json.is_null() {
            data["outputs"] = outputs_json;
        }
        Prompty::load_from_value(&data, &LoadContext::default())
    }

    // -----------------------------------------------------------------------
    // Chat completions
    // -----------------------------------------------------------------------

    // Plain assistant text is returned as a JSON string.
    #[test]
    fn test_process_chat_content() {
        let agent = make_agent(Value::Null);
        let response = json!({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "Hello!"
                }
            }]
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result, json!("Hello!"));
    }

    // Tool calls are flattened to `{id, name, arguments}` objects.
    #[test]
    fn test_process_chat_tool_calls() {
        let agent = make_agent(Value::Null);
        let response = json!({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": "call_1",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{\"city\":\"SF\"}"
                        }
                    }]
                }
            }]
        });
        let result = process_response(&agent, &response).unwrap();
        let calls = result.as_array().unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0]["name"], "get_weather");
        assert_eq!(calls[0]["id"], "call_1");
    }

    // Null content with a `refusal` string returns the refusal text.
    #[test]
    fn test_process_chat_refusal() {
        let agent = make_agent(Value::Null);
        let response = json!({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": null,
                    "refusal": "I can't do that"
                }
            }]
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result, json!("I can't do that"));
    }

    // With declared outputs, JSON content is parsed into a structured value.
    #[test]
    fn test_process_structured_output() {
        let agent = make_agent(json!([
            {"name": "answer", "kind": "string", "required": true}
        ]));
        let response = json!({
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": "{\"answer\": \"42\"}"
                }
            }]
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result["answer"], "42");
    }

    // -----------------------------------------------------------------------
    // Embeddings and images
    // -----------------------------------------------------------------------

    // A single embedding is returned unwrapped, not as a one-element array.
    #[test]
    fn test_process_embedding_single() {
        let agent = make_agent(Value::Null);
        let response = json!({
            "object": "list",
            "data": [{
                "object": "embedding",
                "embedding": [0.1, 0.2, 0.3]
            }]
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result, json!([0.1, 0.2, 0.3]));
    }

    #[test]
    fn test_process_embedding_multiple() {
        let agent = make_agent(Value::Null);
        let response = json!({
            "object": "list",
            "data": [
                {"object": "embedding", "embedding": [0.1, 0.2]},
                {"object": "embedding", "embedding": [0.3, 0.4]}
            ]
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result, json!([[0.1, 0.2], [0.3, 0.4]]));
    }

    // A single image URL is returned unwrapped.
    #[test]
    fn test_process_image_single() {
        let agent = make_agent(Value::Null);
        let response = json!({
            "data": [{"url": "https://example.com/image.png"}]
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result, json!("https://example.com/image.png"));
    }

    #[test]
    fn test_process_image_multiple() {
        let agent = make_agent(Value::Null);
        let response = json!({
            "data": [
                {"url": "https://a.png"},
                {"url": "https://b.png"}
            ]
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result, json!(["https://a.png", "https://b.png"]));
    }

    // -----------------------------------------------------------------------
    // Tool-call extraction from processed values
    // -----------------------------------------------------------------------

    #[test]
    fn test_extract_tool_calls() {
        let val = json!([
            {"id": "c1", "name": "fn1", "arguments": "{}"},
            {"id": "c2", "name": "fn2", "arguments": "{\"x\":1}"}
        ]);
        let calls = extract_tool_calls(&val).unwrap();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "fn1");
        assert_eq!(calls[1].name, "fn2");
    }

    // Non-array values are never treated as tool calls.
    #[test]
    fn test_extract_tool_calls_not_tool_response() {
        assert!(extract_tool_calls(&json!("Hello")).is_none());
        assert!(extract_tool_calls(&json!(42)).is_none());
    }

    // -----------------------------------------------------------------------
    // Edge cases: empty choices, missing message, malformed tool calls
    // -----------------------------------------------------------------------

    #[test]
    fn test_empty_choices_error() {
        let agent = Prompty::default();
        let response = json!({
            "choices": []
        });
        let err = process_response(&agent, &response).unwrap_err();
        assert!(err.to_string().contains("Empty choices"));
    }

    #[test]
    fn test_missing_message_error() {
        let agent = Prompty::default();
        let response = json!({
            "choices": [{"finish_reason": "stop"}]
        });
        let err = process_response(&agent, &response).unwrap_err();
        assert!(err.to_string().contains("Missing message"));
    }

    #[test]
    fn test_tool_calls_with_missing_fields() {
        let agent = Prompty::default();
        // Tool calls where some entries are malformed
        let response = json!({
            "choices": [{
                "message": {
                    "tool_calls": [
                        {
                            "id": "call_1",
                            "function": {"name": "test", "arguments": "{}"}
                        },
                        {
                            // Missing function block — should still extract with empty defaults
                            "id": "call_2"
                        }
                    ]
                }
            }]
        });
        let result = process_response(&agent, &response).unwrap();
        let arr = result.as_array().unwrap();
        assert_eq!(arr.len(), 2);
        assert_eq!(arr[0]["name"], "test");
        // Malformed entry should have empty defaults
        assert_eq!(arr[1]["name"], "");
    }

    // Null content without a refusal yields an empty string, not an error.
    #[test]
    fn test_null_content_no_refusal() {
        let agent = Prompty::default();
        let response = json!({
            "choices": [{
                "message": {
                    "content": null
                }
            }]
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result, "");
    }

    // Unrecognised shapes are returned exactly as received.
    #[test]
    fn test_unknown_response_shape_passthrough() {
        let agent = Prompty::default();
        let response = json!({
            "unexpected": "format",
            "custom": 42
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result, response);
    }

    #[test]
    fn test_extract_tool_calls_empty_array() {
        // Empty array should return None
        assert!(extract_tool_calls(&json!([])).is_none());
    }

    #[test]
    fn test_extract_tool_calls_array_with_non_tool_objects() {
        // Array of objects without proper tool call fields
        let val = json!([{"foo": "bar"}, {"baz": 42}]);
        assert!(extract_tool_calls(&val).is_none());
    }

    #[test]
    fn test_structured_output_invalid_json_falls_back() {
        // Agent with outputs, but content is not valid JSON — falls back to raw string
        let data = serde_json::json!({
            "kind": "prompt",
            "name": "structured",
            "model": "gpt-4",
            "outputs": [{"name": "result", "kind": "object"}],
            "instructions": "Return JSON"
        });
        let agent = Prompty::load_from_value(&data, &LoadContext::default());
        let response = json!({
            "choices": [{
                "message": {
                    "content": "this is not json"
                }
            }]
        });
        let result = process_response(&agent, &response).unwrap();
        assert_eq!(result, "this is not json");
    }

    #[test]
    fn test_embedding_multiple_vectors() {
        let agent = Prompty::default();
        let response = json!({
            "object": "list",
            "data": [
                {"embedding": [0.1, 0.2]},
                {"embedding": [0.3, 0.4]}
            ]
        });
        let result = process_response(&agent, &response).unwrap();
        let arr = result.as_array().unwrap();
        assert_eq!(arr.len(), 2);
    }

    #[test]
    fn test_image_multiple_urls() {
        let agent = Prompty::default();
        let response = json!({
            "data": [
                {"url": "https://a.com/1.png"},
                {"url": "https://a.com/2.png"}
            ]
        });
        let result = process_response(&agent, &response).unwrap();
        let arr = result.as_array().unwrap();
        assert_eq!(arr.len(), 2);
    }

    // -----------------------------------------------------------------------
    // Streaming processor tests
    // -----------------------------------------------------------------------

    // Content deltas arrive as separate Text chunks; an empty delta yields nothing.
    #[tokio::test]
    async fn test_stream_text_content() {
        use futures::StreamExt;
        let chunks = vec![
            json!({"choices": [{"delta": {"content": "Hello"}}]}),
            json!({"choices": [{"delta": {"content": " world"}}]}),
            json!({"choices": [{"delta": {}}]}), // empty delta
        ];
        let inner = futures::stream::iter(chunks);
        let mut stream = process_stream(inner);
        let mut texts = Vec::new();
        while let Some(chunk) = stream.next().await {
            match chunk {
                StreamChunk::Text(t) => texts.push(t),
                StreamChunk::Tool(_) => panic!("unexpected tool call"),
                _ => {}
            }
        }
        assert_eq!(texts.join(""), "Hello world");
    }

    // Argument fragments for the same `index` are concatenated into one tool call.
    #[tokio::test]
    async fn test_stream_tool_calls() {
        use futures::StreamExt;
        let chunks = vec![
            json!({"choices": [{"delta": {"tool_calls": [
                {"index": 0, "id": "call_1", "function": {"name": "get_weather", "arguments": "{\"ci"}}
            ]}}]}),
            json!({"choices": [{"delta": {"tool_calls": [
                {"index": 0, "function": {"arguments": "ty\":\"SF\"}"}}
            ]}}]}),
        ];
        let inner = futures::stream::iter(chunks);
        let mut stream = process_stream(inner);
        let mut tools = Vec::new();
        while let Some(chunk) = stream.next().await {
            match chunk {
                StreamChunk::Text(_) => {}
                StreamChunk::Tool(tc) => tools.push(tc),
                _ => {}
            }
        }
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].id, "call_1");
        assert_eq!(tools[0].name, "get_weather");
        assert_eq!(tools[0].arguments, "{\"city\":\"SF\"}");
    }

    // A refusal delta produces exactly one Error chunk.
    #[tokio::test]
    async fn test_stream_refusal() {
        use futures::StreamExt;
        let chunks = vec![json!({"choices": [{"delta": {"refusal": "I cannot help with that"}}]})];
        let inner = futures::stream::iter(chunks);
        let mut stream = process_stream(inner);
        let mut errors = Vec::new();
        while let Some(chunk) = stream.next().await {
            if let StreamChunk::Error(e) = chunk {
                errors.push(e);
            }
        }
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("refused"));
    }

    // End-to-end with the shared `consume_stream_chunks` helper.
    #[tokio::test]
    async fn test_stream_with_consume() {
        use prompty::types::consume_stream_chunks;
        let chunks = vec![
            json!({"choices": [{"delta": {"content": "Hello"}}]}),
            json!({"choices": [{"delta": {"content": " "}}]}),
            json!({"choices": [{"delta": {"content": "world"}}]}),
        ];
        let inner = futures::stream::iter(chunks);
        let stream = process_stream(inner);
        let (tool_calls, content) = consume_stream_chunks(stream, None).await;
        assert!(tool_calls.is_empty());
        assert_eq!(content, "Hello world");
    }

    #[tokio::test]
    async fn test_stream_mixed_content_then_tools() {
        use futures::StreamExt;
        // Some providers may send content then tool calls
        let chunks = vec![
            json!({"choices": [{"delta": {"content": "Let me check..."}}]}),
            json!({"choices": [{"delta": {"tool_calls": [
                {"index": 0, "id": "c1", "function": {"name": "search", "arguments": "{}"}}
            ]}}]}),
        ];
        let inner = futures::stream::iter(chunks);
        let mut stream = process_stream(inner);
        let mut texts = Vec::new();
        let mut tools = Vec::new();
        while let Some(chunk) = stream.next().await {
            match chunk {
                StreamChunk::Text(t) => texts.push(t),
                StreamChunk::Tool(tc) => tools.push(tc),
                _ => {}
            }
        }
        assert_eq!(texts.join(""), "Let me check...");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "search");
    }
}