Skip to main content

openai_reassembler/
lib.rs

1//! Reassemble OpenAI-compatible SSE streaming responses into non-streaming format.
2//!
3//! When an OpenAI-compatible API streams a response as Server-Sent Events (SSE),
4//! each event contains a partial "chunk" of the final response. This crate provides
5//! a function to merge those chunks into the equivalent non-streaming JSON response.
6//!
7//! # Supported formats
8//!
9//! - **Chat completions** (`/v1/chat/completions`): merges `choices[].delta` fields
10//!   into `choices[].message`, concatenating string values (e.g. `content`, `refusal`)
11//!   and assembling `tool_calls` by index. Other non-string delta fields use last-value-wins.
12//! - **Legacy completions** (`/v1/completions`): concatenates `choices[].text`.
13//! - **Responses API** (`/v1/responses`): extracts the full response from the
14//!   `response.completed` event.
15//! - **Multiple choices**: tracked independently by `index`.
16//! - **Usage**: taken from the final chunk.
17//!
18//! Format detection is automatic: if any event's `event` field (from
19//! `eventsource_stream::Event`) starts with `"response."`, the Responses API
20//! path is used; otherwise the completions path.
21
22use serde_json::{Map, Value};
23
24/// Reassemble OpenAI-compatible streaming chunks into a non-streaming response.
25///
26/// Auto-detects the stream format and dispatches accordingly:
27/// - **Responses API**: if any event's `event` field starts with `"response."`,
28///   extracts the full response from the `response.completed` event.
29/// - **Completions**: otherwise, merges `choices[].delta` / `choices[].text` chunks.
30///
31/// For the chat and legacy completions endpoints, top-level fields (`id`, `created`,
32/// `model`, etc.) are taken from the first chunk. In this completions path, the
33/// `object` field has the `.chunk` suffix stripped (e.g. `chat.completion.chunk`
34/// → `chat.completion`). Responses API objects are left unchanged.
35///
36/// Events with empty data or `[DONE]` are skipped.
37pub fn reassemble(events: &[eventsource_stream::Event]) -> anyhow::Result<String> {
38    let is_responses_api = events.iter().any(|e| e.event.starts_with("response."));
39    if is_responses_api {
40        return reassemble_responses(events);
41    }
42
43    let mut base: Option<Value> = None;
44    let mut choices: std::collections::BTreeMap<u64, Map<String, Value>> = Default::default();
45    let mut usage = Value::Null;
46
47    for event in events {
48        if event.data.is_empty() || event.data == "[DONE]" {
49            continue;
50        }
51        let chunk: Value = serde_json::from_str(&event.data)
52            .map_err(|e| anyhow::anyhow!("Invalid chunk JSON: {}", e))?;
53
54        if base.is_none() {
55            let mut b = chunk.clone();
56            if let Some(obj) = b["object"].as_str() {
57                b["object"] = Value::String(obj.replace(".chunk", ""));
58            }
59            if let Some(m) = b.as_object_mut() {
60                m.remove("choices");
61                m.remove("usage");
62            }
63            base = Some(b);
64        }
65
66        if !chunk["usage"].is_null() {
67            usage = chunk["usage"].clone();
68        }
69
70        if let Some(chunk_choices) = chunk["choices"].as_array() {
71            for choice in chunk_choices {
72                let index = choice["index"].as_u64().unwrap_or(0);
73                let merged = choices.entry(index).or_default();
74
75                if !choice["finish_reason"].is_null() {
76                    merged.insert("finish_reason".to_string(), choice["finish_reason"].clone());
77                }
78
79                // Legacy completions: concatenate "text"
80                if let Some(text) = choice["text"].as_str() {
81                    let existing = merged
82                        .entry("text".to_string())
83                        .or_insert(Value::String(String::new()));
84                    if let Value::String(s) = existing {
85                        s.push_str(text);
86                    }
87                }
88
89                // Chat completions: merge "delta" into "message"
90                if let Some(delta) = choice["delta"].as_object() {
91                    let message = merged
92                        .entry("message".to_string())
93                        .or_insert(Value::Object(Map::new()));
94                    if let Value::Object(msg) = message {
95                        for (key, value) in delta {
96                            if value.is_null() {
97                                continue;
98                            }
99                            match key.as_str() {
100                                "tool_calls" => merge_tool_calls(msg, value),
101                                _ => merge_delta_field(msg, key, value),
102                            }
103                        }
104                    }
105                }
106            }
107        }
108    }
109
110    let mut response = base.unwrap_or(Value::Object(Map::new()));
111    let assembled_choices: Vec<Value> = choices
112        .into_iter()
113        .map(|(index, mut fields)| {
114            fields.insert("index".to_string(), Value::Number(index.into()));
115            if !fields.contains_key("finish_reason") {
116                fields.insert("finish_reason".to_string(), Value::Null);
117            }
118            Value::Object(fields)
119        })
120        .collect();
121    response["choices"] = Value::Array(assembled_choices);
122    response["usage"] = usage;
123
124    Ok(response.to_string())
125}
126
127/// Reassemble a Responses API SSE stream into a non-streaming response.
128///
129/// The Responses API emits typed events (`response.created`, `response.output_text.delta`,
130/// etc.). The final `response.completed` event contains the full response object under
131/// the `"response"` key. This function finds that event and extracts the response.
132fn reassemble_responses(events: &[eventsource_stream::Event]) -> anyhow::Result<String> {
133    for event in events.iter().rev() {
134        if event.event == "response.completed" {
135            let parsed: Value = serde_json::from_str(&event.data)
136                .map_err(|e| anyhow::anyhow!("Invalid response.completed JSON: {}", e))?;
137            if let Some(response) = parsed.get("response") {
138                return serde_json::to_string(response).map_err(Into::into);
139            }
140            anyhow::bail!(
141                "response.completed event JSON does not contain top-level \"response\" field"
142            );
143        }
144    }
145    anyhow::bail!("No response.completed event found in Responses API SSE stream")
146}
147
148/// Merge streamed tool_calls deltas into the accumulated message.
149///
150/// Tool calls arrive as an array of deltas, each with an `index` field indicating
151/// which tool call slot they belong to. `id` and `type` are set once; `function.name`
152/// and `function.arguments` are concatenated across chunks.
153fn merge_tool_calls(msg: &mut Map<String, Value>, value: &Value) {
154    let Some(arr) = value.as_array() else { return };
155    let tc_list = msg
156        .entry("tool_calls".to_string())
157        .or_insert(Value::Array(vec![]));
158    let Value::Array(existing) = tc_list else {
159        return;
160    };
161
162    for tc_delta in arr {
163        let idx = tc_delta["index"].as_u64().unwrap_or(0) as usize;
164        while existing.len() <= idx {
165            existing.push(Value::Object(Map::new()));
166        }
167        let slot = existing[idx].as_object_mut().unwrap();
168
169        // Set id and type (arrive once, on the first delta for this tool call)
170        for field in ["id", "type"] {
171            if let Some(v) = tc_delta.get(field) {
172                if !v.is_null() {
173                    slot.insert(field.to_string(), v.clone());
174                }
175            }
176        }
177
178        // Concatenate function name and arguments
179        if let Some(func) = tc_delta["function"].as_object() {
180            let f = slot
181                .entry("function".to_string())
182                .or_insert(Value::Object(Map::new()))
183                .as_object_mut()
184                .unwrap();
185            for field in ["name", "arguments"] {
186                if let Some(s) = func.get(field).and_then(|v| v.as_str()) {
187                    let existing = f
188                        .entry(field.to_string())
189                        .or_insert(Value::String(String::new()));
190                    if let Value::String(es) = existing {
191                        es.push_str(s);
192                    }
193                }
194            }
195        }
196    }
197}
198
199/// Merge a single delta field into the accumulated message.
200///
201/// String fields (content, role, refusal, etc.) are concatenated.
202/// Non-string fields use last-value-wins.
203fn merge_delta_field(msg: &mut Map<String, Value>, key: &str, value: &Value) {
204    if let Some(s) = value.as_str() {
205        let existing = msg
206            .entry(key.to_string())
207            .or_insert(Value::String(String::new()));
208        if let Value::String(existing_str) = existing {
209            existing_str.push_str(s);
210        }
211    } else {
212        msg.insert(key.to_string(), value.clone());
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};
    use std::sync::Once;

    /// Guards fixture generation so it runs at most once per test binary.
    static GENERATE: Once = Once::new();

    /// Crate root: holds `test_cases.json` and the `fixtures/` directory.
    fn manifest_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
    }

    /// Read and parse `test_cases.json` (case name → `endpoint`, `body`, and optional
    /// `allowed_mismatches`).
    fn load_test_cases() -> Value {
        let path = manifest_dir().join("test_cases.json");
        let text = std::fs::read_to_string(&path)
            .unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
        serde_json::from_str(&text)
            .unwrap_or_else(|e| panic!("invalid JSON in {}: {e}", path.display()))
    }

    /// If BASE_URL and MODEL are set, (re)record every fixture once before tests run.
    /// API_KEY is optional and defaults to "none".
    fn ensure_fixtures() {
        GENERATE.call_once(|| {
            let (Ok(base_url), Ok(model)) = (std::env::var("BASE_URL"), std::env::var("MODEL"))
            else {
                return;
            };
            let api_key = std::env::var("API_KEY").unwrap_or_else(|_| "none".to_string());
            let fixtures_dir = manifest_dir().join("fixtures");
            std::fs::create_dir_all(&fixtures_dir).unwrap();

            let cases = load_test_cases();
            let rt = tokio::runtime::Runtime::new().unwrap();
            let client = reqwest::Client::new();

            for (name, case) in cases.as_object().unwrap() {
                let endpoint = case["endpoint"].as_str().unwrap();
                if endpoint.ends_with("/responses") {
                    rt.block_on(record_responses_fixture(
                        &client,
                        &base_url,
                        &api_key,
                        &model,
                        name,
                        case,
                        &fixtures_dir,
                    ));
                } else {
                    rt.block_on(record_fixture(
                        &client,
                        &base_url,
                        &api_key,
                        &model,
                        name,
                        case,
                        &fixtures_dir,
                    ));
                }
            }
        });
    }

    /// The case's request body plus the settings shared by every recording: the model
    /// under test and deterministic sampling (`temperature: 0`, `seed: 42`).
    fn request_body(case: &Value, model: &str) -> Map<String, Value> {
        let mut body = case["body"].as_object().unwrap().clone();
        body.insert("model".to_string(), Value::String(model.to_string()));
        body.insert("temperature".to_string(), Value::Number(0.into()));
        body.insert("seed".to_string(), Value::Number(42.into()));
        body
    }

    /// POST `body` as JSON with bearer auth. Panics (messages prefixed with `what`) on
    /// transport failure or a non-2xx status, so an error payload is never recorded
    /// as a fixture.
    async fn post(
        client: &reqwest::Client,
        url: &str,
        api_key: &str,
        body: &Map<String, Value>,
        what: &str,
    ) -> reqwest::Response {
        client
            .post(url)
            .bearer_auth(api_key)
            .json(body)
            .send()
            .await
            .unwrap_or_else(|e| panic!("{what} request failed: {e}"))
            .error_for_status()
            .unwrap_or_else(|e| panic!("{what} request returned an error status: {e}"))
    }

    /// Split an SSE body into `(event_type, data)` pairs following the SSE framing rules:
    /// `data:` lines accumulate (joined with `\n`) until a blank line ends the event, one
    /// optional space after the colon is dropped, `\r` line endings are tolerated, events
    /// without any `data:` line are discarded, and a final event with no trailing blank
    /// line is still emitted. `event_type` is empty when the event had no `event:` line.
    fn parse_sse(text: &str) -> Vec<(String, String)> {
        let mut events = Vec::new();
        let mut event_type = String::new();
        let mut data_lines: Vec<&str> = Vec::new();

        // The appended "" acts as a final blank line, flushing an unterminated event.
        for raw_line in text.lines().chain(std::iter::once("")) {
            let line = raw_line.trim_end_matches('\r');
            if line.is_empty() {
                if !data_lines.is_empty() {
                    events.push((std::mem::take(&mut event_type), data_lines.join("\n")));
                }
                event_type.clear();
                data_lines.clear();
            } else if let Some(value) = line.strip_prefix("event:") {
                event_type = value.strip_prefix(' ').unwrap_or(value).to_string();
            } else if let Some(value) = line.strip_prefix("data:") {
                data_lines.push(value.strip_prefix(' ').unwrap_or(value));
            }
        }
        events
    }

    /// Write `fixture` as pretty JSON (plus trailing newline) to `<fixtures_dir>/<name>.json`.
    fn write_fixture(fixtures_dir: &Path, name: &str, fixture: &Value) {
        let path = fixtures_dir.join(format!("{name}.json"));
        std::fs::write(&path, serde_json::to_string_pretty(fixture).unwrap() + "\n")
            .unwrap_or_else(|e| panic!("{name}: failed to write fixture: {e}"));
        eprintln!("[{name}] fixture written to {}", path.display());
    }

    /// Record a fixture for a completions endpoint (chat or legacy).
    ///
    /// Sends the case body twice — with `stream: false` (the expected output) and with
    /// `stream: true` plus `stream_options.include_usage` — and stores the streamed data
    /// payloads under `"chunks"` next to the non-streaming `"expected"` response.
    async fn record_fixture(
        client: &reqwest::Client,
        base_url: &str,
        api_key: &str,
        model: &str,
        name: &str,
        case: &Value,
        fixtures_dir: &Path,
    ) {
        let endpoint = case["endpoint"].as_str().unwrap();
        let url = format!("{base_url}{endpoint}");
        let body = request_body(case, model);

        // Non-streaming
        let mut non_stream_body = body.clone();
        non_stream_body.insert("stream".to_string(), Value::Bool(false));
        eprintln!("[{name}] POST {url} (non-streaming)");
        let expected: Value = post(
            client,
            &url,
            api_key,
            &non_stream_body,
            &format!("{name}: non-streaming"),
        )
        .await
        .json()
        .await
        .unwrap_or_else(|e| panic!("{name}: non-streaming parse failed: {e}"));
        eprintln!("[{name}] non-streaming response received");

        // Streaming, with usage requested in the final chunk
        let mut stream_body = body;
        stream_body.insert("stream".to_string(), Value::Bool(true));
        stream_body.insert(
            "stream_options".to_string(),
            serde_json::json!({ "include_usage": true }),
        );

        eprintln!("[{name}] POST {url} (streaming)");
        let response_text = post(client, &url, api_key, &stream_body, &format!("{name}: streaming"))
            .await
            .text()
            .await
            .unwrap_or_else(|e| panic!("{name}: streaming read failed: {e}"));

        // Keep only data payloads: parsed JSON chunks, plus the literal "[DONE]" sentinel
        // (stored as a JSON string) so that replay also exercises skipping it.
        let chunks: Vec<Value> = parse_sse(&response_text)
            .into_iter()
            .filter_map(|(_, data)| {
                if data == "[DONE]" {
                    Some(Value::String(data))
                } else {
                    serde_json::from_str(&data).ok()
                }
            })
            .collect();

        eprintln!("[{name}] streaming response: {} chunks", chunks.len());
        write_fixture(
            fixtures_dir,
            name,
            &serde_json::json!({ "chunks": chunks, "expected": expected }),
        );
    }

    /// Record a fixture for the Responses API.
    ///
    /// Unlike completions, responses SSE events are typed (e.g. `event: response.created`),
    /// so each event is stored as `{ "event_type": ..., "data": ... }` under `"events"`.
    /// The non-streaming request omits `stream` entirely rather than setting it to false,
    /// and no `stream_options` are sent.
    async fn record_responses_fixture(
        client: &reqwest::Client,
        base_url: &str,
        api_key: &str,
        model: &str,
        name: &str,
        case: &Value,
        fixtures_dir: &Path,
    ) {
        let endpoint = case["endpoint"].as_str().unwrap();
        let url = format!("{base_url}{endpoint}");
        let mut body = request_body(case, model);

        // Non-streaming (no stream field at all for responses API)
        eprintln!("[{name}] POST {url} (non-streaming)");
        let expected: Value = post(client, &url, api_key, &body, &format!("{name}: non-streaming"))
            .await
            .json()
            .await
            .unwrap_or_else(|e| panic!("{name}: non-streaming parse failed: {e}"));
        eprintln!("[{name}] non-streaming response received");

        // Streaming
        body.insert("stream".to_string(), Value::Bool(true));
        eprintln!("[{name}] POST {url} (streaming)");
        let response_text = post(client, &url, api_key, &body, &format!("{name}: streaming"))
            .await
            .text()
            .await
            .unwrap_or_else(|e| panic!("{name}: streaming read failed: {e}"));

        // Keep each payload with its event type; "[DONE]" and any other non-JSON data
        // fail to parse and are dropped.
        let events: Vec<Value> = parse_sse(&response_text)
            .into_iter()
            .filter_map(|(event_type, data)| {
                let parsed: Value = serde_json::from_str(&data).ok()?;
                Some(serde_json::json!({ "event_type": event_type, "data": parsed }))
            })
            .collect();

        eprintln!("[{name}] streaming response: {} events", events.len());
        write_fixture(
            fixtures_dir,
            name,
            &serde_json::json!({ "events": events, "expected": expected }),
        );
    }

    /// Dotted path of `key` under `path`, as used in mismatch messages.
    fn child_path(path: &str, key: &str) -> String {
        if path.is_empty() {
            key.to_string()
        } else {
            format!("{path}.{key}")
        }
    }

    /// True when both values are strings that parse to equal JSON values.
    fn same_json_text(actual: &Value, expected: &Value) -> bool {
        let (Some(a), Some(e)) = (actual.as_str(), expected.as_str()) else {
            return false;
        };
        match (
            serde_json::from_str::<Value>(a),
            serde_json::from_str::<Value>(e),
        ) {
            (Ok(ap), Ok(ep)) => ap == ep,
            _ => false,
        }
    }

    /// Recursively compare two JSON values, collecting mismatches into `errors`.
    /// Object fields named in `skip` are ignored at any nesting depth. Strings at a path
    /// ending in `.arguments` (tool-call arguments) are compared as parsed JSON, since
    /// whitespace may differ.
    fn diff(
        actual: &Value,
        expected: &Value,
        path: &str,
        skip: &[String],
        errors: &mut Vec<String>,
    ) {
        match (actual, expected) {
            (Value::Object(a), Value::Object(e)) => {
                for (key, ev) in e {
                    if skip.iter().any(|s| s == key) {
                        continue;
                    }
                    let p = child_path(path, key);
                    match a.get(key) {
                        Some(av) => diff(av, ev, &p, skip, errors),
                        None => errors.push(format!("{p}: missing from reassembled output")),
                    }
                }
                for key in a.keys() {
                    if skip.iter().any(|s| s == key) || e.contains_key(key) {
                        continue;
                    }
                    errors.push(format!(
                        "{}: unexpected field in reassembled output",
                        child_path(path, key)
                    ));
                }
            }
            (Value::Array(a), Value::Array(e)) => {
                if a.len() != e.len() {
                    errors.push(format!(
                        "{path}: array length {}, expected {}",
                        a.len(),
                        e.len()
                    ));
                    return;
                }
                for (i, (av, ev)) in a.iter().zip(e).enumerate() {
                    diff(av, ev, &format!("{path}[{i}]"), skip, errors);
                }
            }
            _ => {
                if actual != expected
                    && !(path.ends_with(".arguments") && same_json_text(actual, expected))
                {
                    errors.push(format!("{path}: got {actual}, expected {expected}"));
                }
            }
        }
    }

    /// Ensure fixtures exist, then load `fixtures/<name>.json` together with the case's
    /// `allowed_mismatches` (field names to skip when diffing) from `test_cases.json`.
    fn load_fixture(name: &str) -> (Value, Vec<String>) {
        ensure_fixtures();
        let cases = load_test_cases();
        let skip: Vec<String> = cases[name]["allowed_mismatches"]
            .as_array()
            .map(|a| a.iter().map(|v| v.as_str().unwrap().to_string()).collect())
            .unwrap_or_default();

        let path = manifest_dir().join("fixtures").join(format!("{name}.json"));
        let content = std::fs::read_to_string(&path)
            .unwrap_or_else(|e| panic!("missing fixture {}: {e}", path.display()));
        let fixture: Value = serde_json::from_str(&content).unwrap();
        (fixture, skip)
    }

    /// Reassemble `events` and panic listing every mismatch against `fixture["expected"]`.
    fn check_reassembly(
        name: &str,
        events: &[eventsource_stream::Event],
        fixture: &Value,
        skip: &[String],
    ) {
        let actual: Value = serde_json::from_str(&reassemble(events).unwrap()).unwrap();
        let mut errors = vec![];
        diff(&actual, &fixture["expected"], "", skip, &mut errors);
        if !errors.is_empty() {
            panic!("fixture {name}:\n{}", errors.join("\n"));
        }
    }

    /// Replay a completions fixture (`"chunks"`) and verify reassembly.
    fn assert_fixture(name: &str) {
        let (fixture, skip) = load_fixture(name);
        let events: Vec<eventsource_stream::Event> = fixture["chunks"]
            .as_array()
            .unwrap()
            .iter()
            .map(|chunk| eventsource_stream::Event {
                // "[DONE]" is stored as a JSON string; everything else is a chunk object.
                data: chunk
                    .as_str()
                    .map_or_else(|| chunk.to_string(), str::to_string),
                ..Default::default()
            })
            .collect();
        check_reassembly(name, &events, &fixture, &skip);
    }

    /// Replay a Responses API fixture and verify reassembly.
    ///
    /// Responses fixtures store events as `{ "event_type": ..., "data": ... }` objects
    /// under the `"events"` key (not `"chunks"`).
    fn assert_responses_fixture(name: &str) {
        let (fixture, skip) = load_fixture(name);
        let events: Vec<eventsource_stream::Event> = fixture["events"]
            .as_array()
            .unwrap()
            .iter()
            .map(|ev| eventsource_stream::Event {
                event: ev["event_type"].as_str().unwrap_or_default().to_string(),
                data: ev["data"].to_string(),
                ..Default::default()
            })
            .collect();
        check_reassembly(name, &events, &fixture, &skip);
    }

    macro_rules! fixture_test {
        ($name:ident) => {
            #[test]
            fn $name() {
                assert_fixture(stringify!($name));
            }
        };
    }

    macro_rules! responses_fixture_test {
        ($name:ident) => {
            #[test]
            fn $name() {
                assert_responses_fixture(stringify!($name));
            }
        };
    }

    fixture_test!(chat_simple);
    fixture_test!(chat_tools);
    fixture_test!(legacy_completion);
    responses_fixture_test!(responses_simple);
}