// alef_e2e/codegen/streaming_assertions.rs

//! Shared streaming-virtual-fields module for e2e test codegen.
//!
//! Chat-stream fixtures assert on "virtual" fields that don't exist on the
//! stream result type itself — `chunks`, `chunks.length`, `stream_content`,
//! `stream_complete`, `no_chunks_after_done`, `tool_calls`, `finish_reason`.
//! These fields resolve against the *collected* list of chunks produced by
//! draining the stream.
//!
//! [`StreamingFieldResolver`] provides two entry points:
//! - [`StreamingFieldResolver::accessor`] — the language-specific expression
//!   for a virtual field given a local variable that holds the collected list.
//! - [`StreamingFieldResolver::collect_snippet`] — the language-specific
//!   code snippet that drains a stream variable into the collected list.
//!
//! ## Convention
//!
//! The `chunks_var` parameter is the local variable name that holds the
//! collected list (conventionally `"chunks"`).  The `stream_var` parameter is
//! the result variable produced by the stream call (conventionally `"result"`).
//!
//! The set of streaming-virtual field names handled by this module:
//! - `chunks`              → the collected list itself
//! - `chunks.length`       → length/count of the collected list
//! - `stream_content`      → concatenation of all delta content strings
//! - `stream_complete`     → boolean — last chunk has a non-null finish_reason
//! - `no_chunks_after_done` → structural invariant (true by construction for
//!   channel/iterator-based APIs once the channel is closed); the accessor
//!   yields the literal expression `true` for every language, which callers
//!   wrap in their assertion (e.g. `assert!(true)` / `assertTrue(true)`)
//! - `tool_calls`          → flat list of tool_calls from all chunk deltas
//! - `finish_reason`       → finish_reason string from the last chunk

/// Field names that resolve against the collected chunk list rather than
/// against a field of the stream result type.
///
/// Each entry has a matching arm in [`StreamingFieldResolver::accessor`].
pub const STREAMING_VIRTUAL_FIELDS: &[&str] = &[
    // The collected list and its size.
    "chunks",
    "chunks.length",
    // Aggregates over the chunk deltas.
    "stream_content",
    "stream_complete",
    // Structural invariant of a fully drained stream.
    "no_chunks_after_done",
    // Data read from the chunks' choices.
    "tool_calls",
    "finish_reason",
];

43/// Returns `true` when `field` is a streaming-virtual field name.
44pub fn is_streaming_virtual_field(field: &str) -> bool {
45    STREAMING_VIRTUAL_FIELDS.contains(&field)
46}
47
/// Shared streaming-virtual-fields resolver for e2e test codegen.
///
/// A stateless namespace: both entry points are associated functions keyed on
/// the target language name (`"rust"`, `"go"`, `"java"`, `"php"`, `"node"`,
/// `"wasm"`, `"typescript"`, …).
pub struct StreamingFieldResolver;

impl StreamingFieldResolver {
    /// Returns the language-specific expression for a streaming-virtual field,
    /// given `chunks_var` (the collected-list local name) and `lang`.
    ///
    /// Expressions built from a top-level `&&` or `?:` are wrapped in
    /// parentheses. This lets them be embedded in a larger expression
    /// (negated, compared, …) without operator-precedence surprises.
    ///
    /// Returns `None` only when `field` is not a known streaming-virtual field.
    /// Languages without a dedicated arm receive the JavaScript/TypeScript
    /// form used for `node`/`wasm`/`typescript`.
    pub fn accessor(field: &str, lang: &str, chunks_var: &str) -> Option<String> {
        let expr = match field {
            "chunks" => chunks_var.to_string(),
            "chunks.length" => Self::chunks_length(lang, chunks_var),
            "stream_content" => Self::stream_content(lang, chunks_var),
            "stream_complete" => Self::stream_complete(lang, chunks_var),
            // Structural invariant: once the stream closes (channel drained /
            // iterator exhausted) no further chunks can arrive, so every
            // language gets the literal `true` as a proof of intent.
            "no_chunks_after_done" => "true".to_string(),
            "tool_calls" => Self::tool_calls(lang, chunks_var),
            "finish_reason" => Self::finish_reason(lang, chunks_var),
            _ => return None,
        };
        Some(expr)
    }

    /// Length/count of the collected list.
    fn chunks_length(lang: &str, chunks_var: &str) -> String {
        match lang {
            "rust" => format!("{chunks_var}.len()"),
            "go" | "python" => format!("len({chunks_var})"),
            // The Java collect snippet builds a java.util.ArrayList, which
            // exposes size() rather than a `length` field.
            "java" => format!("{chunks_var}.size()"),
            "php" => format!("count(${chunks_var})"),
            // node/wasm/typescript use .length
            _ => format!("{chunks_var}.length"),
        }
    }

    /// Concatenation of the first choice's delta content across all chunks;
    /// missing content contributes an empty string.
    fn stream_content(lang: &str, chunks_var: &str) -> String {
        match lang {
            "rust" => format!(
                "{chunks_var}.iter().map(|c| c.choices.first().and_then(|ch| ch.delta.content.as_deref()).unwrap_or(\"\")).collect::<String>()"
            ),
            // Go: chunks is []pkg.ChatCompletionChunk
            "go" => format!(
                "func() string {{ var s string; for _, c := range {chunks_var} {{ if len(c.Choices) > 0 && c.Choices[0].Delta.Content != nil {{ s += *c.Choices[0].Delta.Content }} }}; return s }}()"
            ),
            "java" => format!(
                "{chunks_var}.stream().map(c -> c.choices().stream().findFirst().map(ch -> ch.delta().content() != null ? ch.delta().content() : \"\").orElse(\"\")).collect(java.util.stream.Collectors.joining())"
            ),
            "php" => format!(
                "implode('', array_map(fn($c) => $c->choices[0]->delta->content ?? '', ${chunks_var}))"
            ),
            // Zig: simplified placeholder name for the zig JSON struct path.
            "zig" => format!("{chunks_var}_content"),
            // node/wasm/typescript
            _ => format!(
                "{chunks_var}.map((c: any) => c.choices?.[0]?.delta?.content ?? '').join('')"
            ),
        }
    }

    /// Whether the last chunk's first choice carries a non-null finish reason.
    /// Evaluates to false for an empty list.
    fn stream_complete(lang: &str, chunks_var: &str) -> String {
        match lang {
            "rust" => format!(
                "{chunks_var}.last().and_then(|c| c.choices.first()).and_then(|ch| ch.finish_reason.as_ref()).is_some()"
            ),
            "go" => format!(
                "func() bool {{ if len({chunks_var}) == 0 {{ return false }}; last := {chunks_var}[len({chunks_var})-1]; return len(last.Choices) > 0 && last.Choices[0].FinishReason != nil }}()"
            ),
            "java" => format!(
                "(!{chunks_var}.isEmpty() && {chunks_var}.get({chunks_var}.size()-1).choices().stream().findFirst().flatMap(ch -> java.util.Optional.ofNullable(ch.finishReason())).isPresent())"
            ),
            "php" => format!(
                "(!empty(${chunks_var}) && isset(end(${chunks_var})->choices[0]->finishReason))"
            ),
            // node/wasm/typescript
            _ => format!(
                "({chunks_var}.length > 0 && {chunks_var}[{chunks_var}.length - 1].choices?.[0]?.finishReason != null)"
            ),
        }
    }

    /// Flattened list of tool calls from the chunk deltas.
    /// Rust, Go and Java walk every choice; PHP and the JS fallback read only
    /// the first choice of each chunk.
    fn tool_calls(lang: &str, chunks_var: &str) -> String {
        match lang {
            "rust" => format!(
                "{chunks_var}.iter().flat_map(|c| c.choices.iter().flat_map(|ch| ch.delta.tool_calls.iter().flatten())).collect::<Vec<_>>()"
            ),
            "go" => format!(
                "func() []interface{{}} {{ var tc []interface{{}}; for _, c := range {chunks_var} {{ for _, ch := range c.Choices {{ if ch.Delta.ToolCalls != nil {{ for _, t := range *ch.Delta.ToolCalls {{ tc = append(tc, t) }} }} }} }}; return tc }}()"
            ),
            "java" => format!(
                "{chunks_var}.stream().flatMap(c -> c.choices().stream()).flatMap(ch -> ch.delta().toolCalls() != null ? ch.delta().toolCalls().stream() : java.util.stream.Stream.empty()).toList()"
            ),
            "php" => format!(
                "array_merge(...array_map(fn($c) => $c->choices[0]->delta->toolCalls ?? [], ${chunks_var}))"
            ),
            // node/wasm/typescript
            _ => format!("{chunks_var}.flatMap((c: any) => c.choices?.[0]?.delta?.toolCalls ?? [])"),
        }
    }

    /// Finish reason of the last chunk's first choice. When absent, yields a
    /// language-specific empty value: `""` in Rust/Go, `null` in Java/PHP,
    /// `undefined` in the JS fallback.
    fn finish_reason(lang: &str, chunks_var: &str) -> String {
        match lang {
            "rust" => format!(
                "{chunks_var}.last().and_then(|c| c.choices.first()).and_then(|ch| ch.finish_reason.as_deref()).unwrap_or(\"\")"
            ),
            "go" => format!(
                "func() string {{ if len({chunks_var}) == 0 {{ return \"\" }}; last := {chunks_var}[len({chunks_var})-1]; if len(last.Choices) > 0 && last.Choices[0].FinishReason != nil {{ return *last.Choices[0].FinishReason }}; return \"\" }}()"
            ),
            "java" => format!(
                "({chunks_var}.isEmpty() ? null : {chunks_var}.get({chunks_var}.size()-1).choices().stream().findFirst().map(ch -> ch.finishReason()).orElse(null))"
            ),
            "php" => format!(
                "(!empty(${chunks_var}) ? (end(${chunks_var})->choices[0]->finishReason ?? null) : null)"
            ),
            // node/wasm/typescript — parenthesized like the Java/PHP ternaries.
            _ => format!(
                "({chunks_var}.length > 0 ? {chunks_var}[{chunks_var}.length - 1].choices?.[0]?.finishReason : undefined)"
            ),
        }
    }

    /// Returns the language-specific stream-collect-into-list snippet that
    /// produces `chunks_var` from `stream_var`.
    ///
    /// Returns `None` when the language has no generic collect support. This
    /// includes `zig`, whose streams are opaque handles with JSON output.
    pub fn collect_snippet(lang: &str, stream_var: &str, chunks_var: &str) -> Option<String> {
        let snippet = match lang {
            "rust" => format!(
                "let {chunks_var}: Vec<_> = tokio_stream::StreamExt::collect::<Vec<_>>({stream_var}).await;"
            ),
            "go" => format!(
                "var {chunks_var} []pkg.ChatCompletionChunk\n\tfor chunk := range {stream_var} {{\n\t\t{chunks_var} = append({chunks_var}, chunk)\n\t}}"
            ),
            // The iterator local is derived from `chunks_var` so that two
            // collected streams in one Java method do not redeclare it.
            "java" => format!(
                "var {chunks_var} = new java.util.ArrayList<ChatCompletionChunk>();\n        var {chunks_var}Iter = {stream_var};\n        while ({chunks_var}Iter.hasNext()) {{ {chunks_var}.add({chunks_var}Iter.next()); }}"
            ),
            // preserve_keys = false builds a plain list. With key preservation,
            // an iterator that repeats keys (e.g. a generator using
            // `yield from`) would silently overwrite earlier chunks.
            "php" => format!("${chunks_var} = iterator_to_array(${stream_var}, false);"),
            "node" | "wasm" | "typescript" => format!(
                "const {chunks_var}: any[] = [];\n    for await (const _chunk of {stream_var}) {{ {chunks_var}.push(_chunk); }}"
            ),
            // Zig: draining the opaque stream handle would need dedicated
            // iterator code, so (like any other unsupported language) no
            // snippet is produced and the caller uses the result directly.
            _ => return None,
        };
        Some(snippet)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Drains `stream_var` into `chunks` for `lang`, panicking if unsupported.
    fn snippet(lang: &str, stream_var: &str) -> String {
        StreamingFieldResolver::collect_snippet(lang, stream_var, "chunks")
            .unwrap_or_else(|| panic!("no collect snippet for {lang}"))
    }

    #[test]
    fn is_streaming_virtual_field_recognizes_all_fields() {
        STREAMING_VIRTUAL_FIELDS.iter().for_each(|field| {
            assert!(
                is_streaming_virtual_field(field),
                "field '{field}' not recognized as streaming virtual"
            );
        });
    }

    #[test]
    fn is_streaming_virtual_field_rejects_real_fields() {
        // Genuine result-type fields (and the empty name) are never virtual.
        for real_field in ["content", "choices", "model", ""] {
            assert!(!is_streaming_virtual_field(real_field));
        }
    }

    #[test]
    fn accessor_chunks_returns_var_name() {
        for lang in ["rust", "node"] {
            assert_eq!(
                StreamingFieldResolver::accessor("chunks", lang, "chunks"),
                Some("chunks".to_string())
            );
        }
    }

    #[test]
    fn accessor_chunks_length_uses_language_idiom() {
        let length_expr = |lang: &str| -> String {
            StreamingFieldResolver::accessor("chunks.length", lang, "chunks").unwrap()
        };

        let rust = length_expr("rust");
        assert!(rust.contains(".len()"), "rust: {rust}");

        let go = length_expr("go");
        assert!(go.starts_with("len("), "go: {go}");

        let node = length_expr("node");
        assert!(node.contains(".length"), "node: {node}");

        let php = length_expr("php");
        assert!(php.starts_with("count("), "php: {php}");
    }

    #[test]
    fn accessor_stream_content_rust_uses_iterator() {
        let expr = StreamingFieldResolver::accessor("stream_content", "rust", "chunks")
            .expect("rust has a stream_content accessor");
        assert!(expr.contains(".collect::<String>()"), "rust stream_content: {expr}");
    }

    #[test]
    fn accessor_no_chunks_after_done_returns_true() {
        let langs = ["rust", "go", "java", "php", "node", "wasm"];
        for lang in langs {
            let expr =
                StreamingFieldResolver::accessor("no_chunks_after_done", lang, "chunks").unwrap();
            assert_eq!(expr, "true", "lang {lang}: expected 'true', got '{expr}'");
        }
    }

    #[test]
    fn collect_snippet_rust_uses_tokio_stream() {
        let snip = snippet("rust", "result");
        assert!(snip.contains("tokio_stream::StreamExt::collect"), "rust: {snip}");
        assert!(snip.contains("let chunks"), "rust: {snip}");
    }

    #[test]
    fn collect_snippet_go_drains_channel() {
        let snip = snippet("go", "stream");
        assert!(snip.contains("for chunk := range stream"), "go: {snip}");
    }

    #[test]
    fn collect_snippet_java_uses_iterator() {
        let snip = snippet("java", "result");
        assert!(snip.contains("hasNext()"), "java: {snip}");
    }

    #[test]
    fn collect_snippet_php_uses_iterator_to_array() {
        let snip = snippet("php", "result");
        assert!(snip.contains("iterator_to_array"), "php: {snip}");
    }

    #[test]
    fn collect_snippet_node_uses_for_await() {
        let snip = snippet("node", "result");
        assert!(snip.contains("for await"), "node: {snip}");
    }

    #[test]
    fn accessor_unknown_field_returns_none() {
        assert!(
            StreamingFieldResolver::accessor("nonexistent_field", "rust", "chunks").is_none()
        );
    }
}