Skip to main content

alef_e2e/codegen/rust/
assertions.rs

1//! Assertion rendering for Rust e2e tests.
2
3use std::fmt::Write as FmtWrite;
4
5use crate::escape::escape_rust;
6use crate::field_access::FieldResolver;
7use crate::fixture::Assertion;
8
9use super::assertion_helpers::{
10    render_count_equals_assertion, render_count_min_assertion, render_equals_assertion, render_gte_assertion,
11    render_is_empty_assertion, render_method_result_assertion, render_not_empty_assertion,
12};
13use super::assertion_synthetic::{
14    numeric_literal, render_chunks_have_content, render_chunks_have_embeddings, render_embedding_dimensions,
15    render_embedding_quality, render_embeddings_assertion, render_keywords_assertion, render_keywords_count_assertion,
16    tree_field_access_expr, value_to_rust_string,
17};
18
19/// Render a single assertion into the test function body.
20#[allow(clippy::too_many_arguments)]
21pub fn render_assertion(
22    out: &mut String,
23    assertion: &Assertion,
24    result_var: &str,
25    module: &str,
26    dep_name: &str,
27    is_error_context: bool,
28    unwrapped_fields: &[(String, String)], // (fixture_field, local_var)
29    field_resolver: &FieldResolver,
30    result_is_tree: bool,
31    result_is_simple: bool,
32    result_is_vec: bool,
33    result_is_option: bool,
34    returns_result: bool,
35) {
36    // Vec<T> result: iterate per-element so each assertion checks every element.
37    // Field-path assertions become `for r in &{result} { <assert using r> }`.
38    // Length-style assertions on the Vec itself (no field path) operate on the
39    // Vec directly.
40    let has_field = assertion.field.as_ref().is_some_and(|f| !f.is_empty());
41    if result_is_vec && has_field && !is_error_context {
42        let _ = writeln!(out, "    for r in &{result_var} {{");
43        render_assertion(
44            out,
45            assertion,
46            "r",
47            module,
48            dep_name,
49            is_error_context,
50            unwrapped_fields,
51            field_resolver,
52            result_is_tree,
53            result_is_simple,
54            false, // already inside loop
55            result_is_option,
56            returns_result,
57        );
58        let _ = writeln!(out, "    }}");
59        return;
60    }
61    // Option<T> result: map `is_empty`/`not_empty` to `is_none()`/`is_some()`,
62    // and unwrap the inner value before any other assertion runs.
63    if result_is_option && !is_error_context {
64        let assertion_type = assertion.assertion_type.as_str();
65        if !has_field && (assertion_type == "is_empty" || assertion_type == "not_empty") {
66            let check = if assertion_type == "is_empty" {
67                "is_none"
68            } else {
69                "is_some"
70            };
71            let _ = writeln!(
72                out,
73                "    assert!({result_var}.{check}(), \"expected Option to be {check}\");"
74            );
75            return;
76        }
77        // For any other assertion shape, unwrap the Option and recurse with a
78        // bare reference variable so the rest of the renderer treats the inner
79        // value as the result.
80        let _ = writeln!(
81            out,
82            "    let r = {result_var}.as_ref().expect(\"Option<T> should be Some\");"
83        );
84        render_assertion(
85            out,
86            assertion,
87            "r",
88            module,
89            dep_name,
90            is_error_context,
91            unwrapped_fields,
92            field_resolver,
93            result_is_tree,
94            result_is_simple,
95            result_is_vec,
96            false, // already unwrapped
97            returns_result,
98        );
99        return;
100    }
101    let _ = dep_name;
102    // Handle synthetic fields like chunks_have_content (derived assertions).
103    // These are computed expressions, not real struct fields — intercept before
104    // the is_valid_for_result check so they are never treated as field accesses.
105    if let Some(f) = &assertion.field {
106        match f.as_str() {
107            "chunks_have_content" => {
108                render_chunks_have_content(out, result_var, assertion.assertion_type.as_str());
109                return;
110            }
111            "chunks_have_embeddings" => {
112                render_chunks_have_embeddings(out, result_var, assertion.assertion_type.as_str());
113                return;
114            }
115            "embeddings" => {
116                render_embeddings_assertion(out, result_var, assertion);
117                return;
118            }
119            "embedding_dimensions" => {
120                render_embedding_dimensions(out, result_var, assertion);
121                return;
122            }
123            "embeddings_valid" | "embeddings_finite" | "embeddings_non_zero" | "embeddings_normalized" => {
124                render_embedding_quality(out, result_var, f, assertion.assertion_type.as_str());
125                return;
126            }
127            "keywords" => {
128                render_keywords_assertion(out, result_var, assertion);
129                return;
130            }
131            "keywords_count" => {
132                render_keywords_count_assertion(out, result_var, assertion);
133                return;
134            }
135            _ => {}
136        }
137    }
138
139    // Streaming virtual fields: intercept before is_valid_for_result so they are
140    // never skipped.  These fields resolve against the `chunks` collected-list variable.
141    // Gate on `result_var == "chunks"` so non-streaming tests asserting on ambiguous
142    // fields like `usage.total_tokens` don't accidentally reach for an undefined chunks
143    // var; the streaming codegen always names the collected list `chunks`.
144    if let Some(f) = &assertion.field {
145        if result_var == "chunks"
146            && !f.is_empty()
147            && crate::codegen::streaming_assertions::is_streaming_virtual_field(f)
148        {
149            if let Some(expr) =
150                crate::codegen::streaming_assertions::StreamingFieldResolver::accessor(f, "rust", "chunks")
151            {
152                match assertion.assertion_type.as_str() {
153                    "count_min" => {
154                        if let Some(val) = &assertion.value {
155                            if let Some(n) = val.as_u64() {
156                                let _ = writeln!(
157                                    out,
158                                    "    assert!({expr}.len() >= {n} as usize, \"expected >= {n} chunks\");"
159                                );
160                            }
161                        }
162                    }
163                    "count_equals" => {
164                        if let Some(val) = &assertion.value {
165                            if let Some(n) = val.as_u64() {
166                                let _ = writeln!(
167                                    out,
168                                    "    assert_eq!({expr}.len(), {n} as usize, \"expected exactly {n} chunks\");"
169                                );
170                            }
171                        }
172                    }
173                    "equals" => {
174                        if let Some(serde_json::Value::String(s)) = &assertion.value {
175                            let escaped = crate::escape::escape_rust(s);
176                            let _ = writeln!(out, "    assert_eq!({expr}, \"{escaped}\");");
177                        } else if let Some(val) = &assertion.value {
178                            let lit = super::assertion_synthetic::numeric_literal(val);
179                            let _ = writeln!(out, "    assert_eq!({expr}, {lit});");
180                        }
181                    }
182                    "not_empty" => {
183                        let _ = writeln!(out, "    assert!(!{expr}.is_empty(), \"expected non-empty\");");
184                    }
185                    "is_empty" => {
186                        let _ = writeln!(out, "    assert!({expr}.is_empty(), \"expected empty\");");
187                    }
188                    "is_true" => {
189                        let _ = writeln!(out, "    assert!({expr}, \"expected true\");");
190                    }
191                    "is_false" => {
192                        let _ = writeln!(out, "    assert!(!{expr}, \"expected false\");");
193                    }
194                    "greater_than" => {
195                        if let Some(val) = &assertion.value {
196                            let lit = super::assertion_synthetic::numeric_literal(val);
197                            let _ = writeln!(out, "    assert!({expr} > {lit}, \"expected > {lit}\");");
198                        }
199                    }
200                    "greater_than_or_equal" => {
201                        if let Some(val) = &assertion.value {
202                            let lit = super::assertion_synthetic::numeric_literal(val);
203                            let _ = writeln!(out, "    assert!({expr} >= {lit}, \"expected >= {lit}\");");
204                        }
205                    }
206                    "contains" => {
207                        if let Some(serde_json::Value::String(s)) = &assertion.value {
208                            let escaped = crate::escape::escape_rust(s);
209                            let _ = writeln!(
210                                out,
211                                "    assert!({expr}.contains(\"{escaped}\"), \"expected to contain: {escaped}\");"
212                            );
213                        }
214                    }
215                    _ => {
216                        let _ = writeln!(
217                            out,
218                            "    // streaming field '{f}': assertion type '{}' not rendered",
219                            assertion.assertion_type
220                        );
221                    }
222                }
223            }
224            return;
225        }
226    }
227
228    // Skip assertions on fields that don't exist on the result type.
229    // Exception: fields prefixed with "error." target the error value in error-context
230    // assertions — they are resolved against the error type via accessor_for_error,
231    // not against the success result type, so they must not be skipped here.
232    // However, when NOT in error context (i.e. the call site uses .expect() and binds
233    // the Ok value), there is no Err to inspect — skip error.* assertions with a comment.
234    if let Some(f) = &assertion.field {
235        if !f.is_empty() {
236            if f.starts_with("error.") && !is_error_context {
237                let _ = writeln!(out, "    // skipped: field '{f}' not available on result type");
238                return;
239            }
240            if !f.starts_with("error.") && !field_resolver.is_valid_for_result(f) {
241                let _ = writeln!(out, "    // skipped: field '{f}' not available on result type");
242                return;
243            }
244        }
245    }
246
247    // Check if this field was unwrapped (i.e., it is optional and was bound to a local).
248    let is_unwrapped = assertion
249        .field
250        .as_ref()
251        .is_some_and(|f| unwrapped_fields.iter().any(|(ff, _)| ff == f));
252
253    // When in error context with returns_result=true and accessing a field (not an error check),
254    // we need to unwrap the Result first. The test generator creates a binding like
255    // `let result_ok = result.as_ref().ok();` which we can dereference here.
256    // Exception: fields prefixed with "error." access the Err value, not the Ok value.
257    let has_field = assertion.field.as_ref().is_some_and(|f| !f.is_empty());
258    let is_field_assertion = !matches!(assertion.assertion_type.as_str(), "error" | "not_error");
259    let is_error_field = assertion.field.as_ref().is_some_and(|f| f.starts_with("error."));
260    let effective_result_var =
261        if has_field && is_error_context && returns_result && is_field_assertion && !is_error_field {
262            // Dereference the Option<&T> bound as {result_var}_ok
263            format!("{result_var}_ok.as_ref().unwrap()")
264        } else {
265            result_var.to_string()
266        };
267
268    // Determine field access expression:
269    // 1. If the field was unwrapped to a local var, use that local var name.
270    // 2. When result_is_simple, the function returns a plain type (String etc.) — use result_var.
271    // 3. When the field path is exactly the result var name (sentinel: `field: "result"`),
272    //    refer to the result variable directly to avoid emitting `result.result`.
273    // 4. When the result is a Tree, map pseudo-field names to correct Rust expressions.
274    // 5. When the field starts with "error.", resolve against the error type.
275    // 6. Otherwise, use the field resolver to generate the accessor.
276    let field_access = match &assertion.field {
277        Some(f) if !f.is_empty() => {
278            if let Some((_, local_var)) = unwrapped_fields.iter().find(|(ff, _)| ff == f) {
279                local_var.clone()
280            } else if result_is_simple && !f.starts_with("error.") {
281                // Plain return type (String, Vec<T>, etc.) has no struct fields.
282                // Use the result variable directly so assertions operate on the value itself.
283                // Exception: error.* fields must resolve against the Err value, not the
284                // plain result variable, even when the success type is simple (e.g. Bytes).
285                effective_result_var.clone()
286            } else if f == result_var {
287                // Sentinel: fixture uses `field: "result"` (or matches the result variable name)
288                // to refer to the whole return value, not a struct field named "result".
289                effective_result_var.clone()
290            } else if result_is_tree {
291                // Tree is an opaque type — its "fields" are accessed via root_node() or
292                // free functions. Map known pseudo-field names to correct Rust expressions.
293                tree_field_access_expr(f, &effective_result_var, module)
294            } else if let Some(sub) = f.strip_prefix("error.") {
295                // Error-path field: access a field on the Err value rather than the Ok value.
296                // Inline-bind the error so the expression is self-contained.
297                let err_accessor = field_resolver.accessor_for_error(sub, "rust", "__err");
298                format!("{{ let __err = {result_var}.as_ref().err().unwrap(); {err_accessor} }}")
299            } else {
300                field_resolver.accessor(f, "rust", &effective_result_var)
301            }
302        }
303        _ => effective_result_var,
304    };
305
306    match assertion.assertion_type.as_str() {
307        "error" => {
308            let _ = writeln!(out, "    assert!({result_var}.is_err(), \"expected call to fail\");");
309            if let Some(serde_json::Value::String(msg)) = &assertion.value {
310                let escaped = escape_rust(msg);
311                // Match against the Debug format (variant-name-style) and the Display format
312                // (human-readable text). Fixtures often name the error variant ("BadRequest"),
313                // but Display impls typically lowercase with a colon ("bad request: ..."), so
314                // checking both lets either kind of fixture value match.
315                let _ = writeln!(
316                    out,
317                    "    {{ let __e = {result_var}.as_ref().err().unwrap(); assert!(format!(\"{{:?}}\", __e).contains(\"{escaped}\") || __e.to_string().contains(\"{escaped}\"), \"error message mismatch\"); }}"
318                );
319            }
320        }
321        "not_error" => {
322            // Handled at call site; nothing extra needed here.
323        }
324        "equals" => {
325            render_equals_assertion(out, assertion, &field_access, is_unwrapped, field_resolver);
326        }
327        "contains" => {
328            if let Some(val) = &assertion.value {
329                let expected = value_to_rust_string(val);
330                let line = format!(
331                    "    assert!(format!(\"{{:?}}\", {field_access}).contains({expected}), \"expected to contain: {{}}\", {expected});"
332                );
333                let _ = writeln!(out, "{line}");
334            }
335        }
336        "contains_all" => {
337            if let Some(values) = &assertion.values {
338                for val in values {
339                    let expected = value_to_rust_string(val);
340                    let line = format!(
341                        "    assert!(format!(\"{{:?}}\", {field_access}).contains({expected}), \"expected to contain: {{}}\", {expected});"
342                    );
343                    let _ = writeln!(out, "{line}");
344                }
345            }
346        }
347        "not_contains" => {
348            if let Some(val) = &assertion.value {
349                let expected = value_to_rust_string(val);
350                let line = format!(
351                    "    assert!(!format!(\"{{:?}}\", {field_access}).contains({expected}), \"expected NOT to contain: {{}}\", {expected});"
352                );
353                let _ = writeln!(out, "{line}");
354            }
355        }
356        "not_empty" => {
357            render_not_empty_assertion(
358                out,
359                assertion,
360                &field_access,
361                result_var,
362                result_is_option,
363                is_unwrapped,
364                field_resolver,
365            );
366        }
367        "is_empty" => {
368            render_is_empty_assertion(out, assertion, &field_access, is_unwrapped, field_resolver);
369        }
370        "contains_any" => {
371            if let Some(values) = &assertion.values {
372                let checks: Vec<String> = values
373                    .iter()
374                    .map(|v| {
375                        let expected = value_to_rust_string(v);
376                        format!("{field_access}.contains({expected})")
377                    })
378                    .collect();
379                let joined = checks.join(" || ");
380                let _ = writeln!(
381                    out,
382                    "    assert!({joined}, \"expected to contain at least one of the specified values\");"
383                );
384            }
385        }
386        "greater_than" => {
387            if let Some(val) = &assertion.value {
388                // Skip comparisons with negative values against unsigned types (.len() etc.)
389                if val.as_f64().is_some_and(|n| n < 0.0) {
390                    let _ = writeln!(
391                        out,
392                        "    // skipped: greater_than with negative value is always true for unsigned types"
393                    );
394                } else if val.as_u64() == Some(0) {
395                    if field_access.ends_with(".len()") {
396                        // Clippy prefers !is_empty() over len() > 0 for collections.
397                        let base = field_access.strip_suffix(".len()").unwrap();
398                        let _ = writeln!(out, "    assert!(!{base}.is_empty(), \"expected > 0\");");
399                    } else {
400                        // Scalar types (usize, u64, etc.) — use direct comparison.
401                        let _ = writeln!(out, "    assert!({field_access} > 0, \"expected > 0\");");
402                    }
403                } else {
404                    let lit = numeric_literal(val);
405                    let _ = writeln!(out, "    assert!({field_access} > {lit}, \"expected > {lit}\");");
406                }
407            }
408        }
409        "less_than" => {
410            if let Some(val) = &assertion.value {
411                let lit = numeric_literal(val);
412                let _ = writeln!(out, "    assert!({field_access} < {lit}, \"expected < {lit}\");");
413            }
414        }
415        "greater_than_or_equal" => {
416            render_gte_assertion(out, assertion, &field_access, is_unwrapped, field_resolver);
417        }
418        "less_than_or_equal" => {
419            if let Some(val) = &assertion.value {
420                let lit = numeric_literal(val);
421                let _ = writeln!(out, "    assert!({field_access} <= {lit}, \"expected <= {lit}\");");
422            }
423        }
424        "starts_with" => {
425            if let Some(val) = &assertion.value {
426                let expected = value_to_rust_string(val);
427                let _ = writeln!(
428                    out,
429                    "    assert!({field_access}.starts_with({expected}), \"expected to start with: {{}}\", {expected});"
430                );
431            }
432        }
433        "ends_with" => {
434            if let Some(val) = &assertion.value {
435                let expected = value_to_rust_string(val);
436                let _ = writeln!(
437                    out,
438                    "    assert!({field_access}.ends_with({expected}), \"expected to end with: {{}}\", {expected});"
439                );
440            }
441        }
442        "min_length" => {
443            if let Some(val) = &assertion.value {
444                if let Some(n) = val.as_u64() {
445                    let _ = writeln!(
446                        out,
447                        "    assert!({field_access}.len() >= {n}, \"expected length >= {n}, got {{}}\", {field_access}.len());"
448                    );
449                }
450            }
451        }
452        "max_length" => {
453            if let Some(val) = &assertion.value {
454                if let Some(n) = val.as_u64() {
455                    let _ = writeln!(
456                        out,
457                        "    assert!({field_access}.len() <= {n}, \"expected length <= {n}, got {{}}\", {field_access}.len());"
458                    );
459                }
460            }
461        }
462        "count_min" => {
463            render_count_min_assertion(out, assertion, &field_access, is_unwrapped, field_resolver);
464        }
465        "count_equals" => {
466            render_count_equals_assertion(out, assertion, &field_access, is_unwrapped, field_resolver);
467        }
468        "is_true" => {
469            let _ = writeln!(out, "    assert!({field_access}, \"expected true\");");
470        }
471        "is_false" => {
472            let _ = writeln!(out, "    assert!(!{field_access}, \"expected false\");");
473        }
474        "method_result" => {
475            render_method_result_assertion(out, assertion, &field_access, result_is_tree, module);
476        }
477        other => {
478            panic!("Rust e2e generator: unsupported assertion type: {other}");
479        }
480    }
481}
482
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use super::*;
    use crate::field_access::FieldResolver;
    use crate::fixture::Assertion;

    /// A resolver with no known fields, types, or optional paths.
    fn empty_resolver() -> FieldResolver {
        FieldResolver::new(
            &HashMap::new(),
            &HashSet::new(),
            &HashSet::new(),
            &HashSet::new(),
            &HashSet::new(),
        )
    }

    fn make_assertion(assertion_type: &str, field: Option<&str>, value: Option<serde_json::Value>) -> Assertion {
        Assertion {
            assertion_type: assertion_type.to_string(),
            field: field.map(|s| s.to_string()),
            value,
            ..Default::default()
        }
    }

    /// The result-shape flags varied by these tests; all others are `false`.
    #[derive(Default, Clone, Copy)]
    struct Shape {
        is_error_context: bool,
        is_vec: bool,
        is_option: bool,
    }

    /// Render `assertion` against a result variable named `result` and return the output.
    fn render(assertion: &Assertion, shape: Shape) -> String {
        let resolver = empty_resolver();
        let mut out = String::new();
        render_assertion(
            &mut out,
            assertion,
            "result",
            "my_mod",
            "dep",
            shape.is_error_context,
            &[],
            &resolver,
            false,
            false,
            shape.is_vec,
            shape.is_option,
            false,
        );
        out
    }

    #[test]
    fn render_assertion_error_type_emits_is_err_check() {
        let assertion = make_assertion("error", None, None);
        let out = render(&assertion, Shape { is_error_context: true, ..Shape::default() });
        assert!(out.contains("is_err()"), "got: {out}");
    }

    #[test]
    fn render_assertion_error_with_message_checks_debug_and_display() {
        let assertion = make_assertion("error", None, Some(serde_json::Value::from("BadRequest")));
        let out = render(&assertion, Shape { is_error_context: true, ..Shape::default() });
        assert!(out.contains("BadRequest"), "got: {out}");
        assert!(out.contains("__e.to_string()"), "got: {out}");
    }

    #[test]
    fn render_assertion_vec_result_wraps_in_for_loop() {
        let assertion = make_assertion("not_empty", Some("content"), None);
        let out = render(&assertion, Shape { is_vec: true, ..Shape::default() });
        assert!(out.contains("for r in"), "got: {out}");
    }

    #[test]
    fn render_assertion_not_empty_bare_result_uses_is_empty() {
        let assertion = make_assertion("not_empty", None, None);
        let out = render(&assertion, Shape::default());
        assert!(out.contains("is_empty()"), "got: {out}");
    }

    #[test]
    fn render_assertion_option_is_empty_maps_to_is_none() {
        let assertion = make_assertion("is_empty", None, None);
        let out = render(&assertion, Shape { is_option: true, ..Shape::default() });
        assert!(out.contains("result.is_none()"), "got: {out}");
    }

    #[test]
    fn render_assertion_option_not_empty_maps_to_is_some() {
        let assertion = make_assertion("not_empty", None, None);
        let out = render(&assertion, Shape { is_option: true, ..Shape::default() });
        assert!(out.contains("result.is_some()"), "got: {out}");
    }

    #[test]
    fn render_assertion_option_other_assertion_unwraps_then_recurses() {
        let assertion = make_assertion("is_true", None, None);
        let out = render(&assertion, Shape { is_option: true, ..Shape::default() });
        assert_eq!(
            out,
            "    let r = result.as_ref().expect(\"Option<T> should be Some\");\n    assert!(r, \"expected true\");\n"
        );
    }

    #[test]
    fn render_assertion_is_true_bare_result() {
        let assertion = make_assertion("is_true", None, None);
        let out = render(&assertion, Shape::default());
        assert_eq!(out, "    assert!(result, \"expected true\");\n");
    }

    #[test]
    fn render_assertion_greater_than_zero_on_scalar() {
        let assertion = make_assertion("greater_than", None, Some(serde_json::Value::from(0u64)));
        let out = render(&assertion, Shape::default());
        assert_eq!(out, "    assert!(result > 0, \"expected > 0\");\n");
    }

    #[test]
    fn render_assertion_greater_than_negative_is_skipped() {
        let assertion = make_assertion("greater_than", None, Some(serde_json::Value::from(-1i64)));
        let out = render(&assertion, Shape::default());
        assert!(out.contains("// skipped"), "got: {out}");
    }

    #[test]
    fn render_assertion_error_field_outside_error_context_is_skipped() {
        let assertion = make_assertion("equals", Some("error.code"), None);
        let out = render(&assertion, Shape::default());
        assert_eq!(out, "    // skipped: field 'error.code' not available on result type\n");
    }

    #[test]
    #[should_panic(expected = "unsupported assertion type")]
    fn render_assertion_unknown_type_panics() {
        let assertion = make_assertion("no_such_assertion", None, None);
        let _ = render(&assertion, Shape::default());
    }
}