Skip to main content

alef_e2e/codegen/rust/
assertions.rs

1//! Assertion rendering for Rust e2e tests.
2
3use std::fmt::Write as FmtWrite;
4
5use crate::escape::escape_rust;
6use crate::field_access::FieldResolver;
7use crate::fixture::Assertion;
8
9use super::assertion_helpers::{
10    render_count_equals_assertion, render_count_min_assertion, render_equals_assertion, render_gte_assertion,
11    render_is_empty_assertion, render_method_result_assertion, render_not_empty_assertion,
12};
13use super::assertion_synthetic::{
14    numeric_literal, render_chunks_have_content, render_chunks_have_embeddings, render_embedding_dimensions,
15    render_embedding_quality, render_embeddings_assertion, render_keywords_assertion, render_keywords_count_assertion,
16    tree_field_access_expr, value_to_rust_string,
17};
18
19/// Render a single assertion into the test function body.
20#[allow(clippy::too_many_arguments)]
21pub fn render_assertion(
22    out: &mut String,
23    assertion: &Assertion,
24    result_var: &str,
25    module: &str,
26    dep_name: &str,
27    is_error_context: bool,
28    unwrapped_fields: &[(String, String)], // (fixture_field, local_var)
29    field_resolver: &FieldResolver,
30    result_is_tree: bool,
31    result_is_simple: bool,
32    result_is_vec: bool,
33    result_is_option: bool,
34    returns_result: bool,
35) {
36    render_assertion_with_streaming(
37        out,
38        assertion,
39        result_var,
40        module,
41        dep_name,
42        is_error_context,
43        unwrapped_fields,
44        field_resolver,
45        result_is_tree,
46        result_is_simple,
47        result_is_vec,
48        result_is_option,
49        returns_result,
50        false,
51    )
52}
53
54/// Same as [`render_assertion`], but with an `is_streaming` flag so the streaming-virtual
55/// field arm can fire when `result_var` is the raw call result rather than the collected
56/// `chunks` variable.  Callers that already drained the stream into a `chunks: Vec<_>`
57/// local should pass `is_streaming = true`.
58#[allow(clippy::too_many_arguments)]
59pub fn render_assertion_with_streaming(
60    out: &mut String,
61    assertion: &Assertion,
62    result_var: &str,
63    module: &str,
64    dep_name: &str,
65    is_error_context: bool,
66    unwrapped_fields: &[(String, String)], // (fixture_field, local_var)
67    field_resolver: &FieldResolver,
68    result_is_tree: bool,
69    result_is_simple: bool,
70    result_is_vec: bool,
71    result_is_option: bool,
72    returns_result: bool,
73    _is_streaming: bool,
74) {
75    // Vec<T> result: iterate per-element so each assertion checks every element.
76    // Field-path assertions become `for r in &{result} { <assert using r> }`.
77    // Length-style assertions on the Vec itself (no field path) operate on the
78    // Vec directly.
79    let has_field = assertion.field.as_ref().is_some_and(|f| !f.is_empty());
80    if result_is_vec && has_field && !is_error_context {
81        let _ = writeln!(out, "    for r in &{result_var} {{");
82        render_assertion(
83            out,
84            assertion,
85            "r",
86            module,
87            dep_name,
88            is_error_context,
89            unwrapped_fields,
90            field_resolver,
91            result_is_tree,
92            result_is_simple,
93            false, // already inside loop
94            result_is_option,
95            returns_result,
96        );
97        let _ = writeln!(out, "    }}");
98        return;
99    }
100    // Option<T> result: map `is_empty`/`not_empty` to `is_none()`/`is_some()`,
101    // and unwrap the inner value before any other assertion runs.
102    if result_is_option && !is_error_context {
103        let assertion_type = assertion.assertion_type.as_str();
104        if !has_field && (assertion_type == "is_empty" || assertion_type == "not_empty") {
105            let check = if assertion_type == "is_empty" {
106                "is_none"
107            } else {
108                "is_some"
109            };
110            let _ = writeln!(
111                out,
112                "    assert!({result_var}.{check}(), \"expected Option to be {check}\");"
113            );
114            return;
115        }
116        // For any other assertion shape, unwrap the Option and recurse with a
117        // bare reference variable so the rest of the renderer treats the inner
118        // value as the result.
119        let _ = writeln!(
120            out,
121            "    let r = {result_var}.as_ref().expect(\"Option<T> should be Some\");"
122        );
123        render_assertion(
124            out,
125            assertion,
126            "r",
127            module,
128            dep_name,
129            is_error_context,
130            unwrapped_fields,
131            field_resolver,
132            result_is_tree,
133            result_is_simple,
134            result_is_vec,
135            false, // already unwrapped
136            returns_result,
137        );
138        return;
139    }
140    let _ = dep_name;
141    // Handle synthetic fields like chunks_have_content (derived assertions).
142    // These are computed expressions, not real struct fields — intercept before
143    // the is_valid_for_result check so they are never treated as field accesses.
144    if let Some(f) = &assertion.field {
145        match f.as_str() {
146            "chunks_have_content" => {
147                render_chunks_have_content(out, result_var, assertion.assertion_type.as_str());
148                return;
149            }
150            "chunks_have_embeddings" => {
151                render_chunks_have_embeddings(out, result_var, assertion.assertion_type.as_str());
152                return;
153            }
154            "embeddings" => {
155                render_embeddings_assertion(out, result_var, assertion);
156                return;
157            }
158            "embedding_dimensions" => {
159                render_embedding_dimensions(out, result_var, assertion);
160                return;
161            }
162            "embeddings_valid" | "embeddings_finite" | "embeddings_non_zero" | "embeddings_normalized" => {
163                render_embedding_quality(out, result_var, f, assertion.assertion_type.as_str());
164                return;
165            }
166            "keywords" => {
167                render_keywords_assertion(out, result_var, assertion);
168                return;
169            }
170            "keywords_count" => {
171                render_keywords_count_assertion(out, result_var, assertion);
172                return;
173            }
174            _ => {}
175        }
176    }
177
178    // Streaming virtual fields: intercept before is_valid_for_result so they are
179    // never skipped.  These fields resolve against the `chunks` collected-list variable.
180    //
181    // For streaming fixtures, `chunks` is bound by the collect snippet emitted in
182    // `render_test_function`.  For non-streaming fixtures whose result struct has a
183    // literal field whose name collides with a streaming-virtual name (e.g. `chunks`,
184    // `imports`, `structure`), `render_test_function` emits `let {f} = &result.{f};`
185    // before assertions, so the hardcoded `chunks` identifier used below still resolves.
186    if let Some(f) = &assertion.field {
187        if !f.is_empty() && crate::codegen::streaming_assertions::is_streaming_virtual_field(f) {
188            if let Some(expr) =
189                crate::codegen::streaming_assertions::StreamingFieldResolver::accessor(f, "rust", "chunks")
190            {
191                match assertion.assertion_type.as_str() {
192                    "count_min" => {
193                        if let Some(val) = &assertion.value {
194                            if let Some(n) = val.as_u64() {
195                                let _ = writeln!(
196                                    out,
197                                    "    assert!({expr}.len() >= {n} as usize, \"expected >= {n} chunks\");"
198                                );
199                            }
200                        }
201                    }
202                    "count_equals" => {
203                        if let Some(val) = &assertion.value {
204                            if let Some(n) = val.as_u64() {
205                                let _ = writeln!(
206                                    out,
207                                    "    assert_eq!({expr}.len(), {n} as usize, \"expected exactly {n} chunks\");"
208                                );
209                            }
210                        }
211                    }
212                    "equals" => {
213                        if let Some(serde_json::Value::String(s)) = &assertion.value {
214                            let escaped = crate::escape::escape_rust(s);
215                            let _ = writeln!(out, "    assert_eq!({expr}, \"{escaped}\");");
216                        } else if let Some(val) = &assertion.value {
217                            let lit = super::assertion_synthetic::numeric_literal(val);
218                            let _ = writeln!(out, "    assert_eq!({expr}, {lit});");
219                        }
220                    }
221                    "not_empty" => {
222                        let _ = writeln!(out, "    assert!(!{expr}.is_empty(), \"expected non-empty\");");
223                    }
224                    "is_empty" => {
225                        let _ = writeln!(out, "    assert!({expr}.is_empty(), \"expected empty\");");
226                    }
227                    "is_true" => {
228                        let _ = writeln!(out, "    assert!({expr}, \"expected true\");");
229                    }
230                    "is_false" => {
231                        let _ = writeln!(out, "    assert!(!{expr}, \"expected false\");");
232                    }
233                    "greater_than" => {
234                        if let Some(val) = &assertion.value {
235                            let lit = super::assertion_synthetic::numeric_literal(val);
236                            let _ = writeln!(out, "    assert!({expr} > {lit}, \"expected > {lit}\");");
237                        }
238                    }
239                    "greater_than_or_equal" => {
240                        if let Some(val) = &assertion.value {
241                            let lit = super::assertion_synthetic::numeric_literal(val);
242                            let _ = writeln!(out, "    assert!({expr} >= {lit}, \"expected >= {lit}\");");
243                        }
244                    }
245                    "contains" => {
246                        if let Some(serde_json::Value::String(s)) = &assertion.value {
247                            let escaped = crate::escape::escape_rust(s);
248                            let _ = writeln!(
249                                out,
250                                "    assert!({expr}.contains(\"{escaped}\"), \"expected to contain: {escaped}\");"
251                            );
252                        }
253                    }
254                    _ => {
255                        let _ = writeln!(
256                            out,
257                            "    // streaming field '{f}': assertion type '{}' not rendered",
258                            assertion.assertion_type
259                        );
260                    }
261                }
262            }
263            return;
264        }
265    }
266
267    // Skip assertions on fields that don't exist on the result type.
268    // Exception: fields prefixed with "error." target the error value in error-context
269    // assertions — they are resolved against the error type via accessor_for_error,
270    // not against the success result type, so they must not be skipped here.
271    // However, when NOT in error context (i.e. the call site uses .expect() and binds
272    // the Ok value), there is no Err to inspect — skip error.* assertions with a comment.
273    if let Some(f) = &assertion.field {
274        if !f.is_empty() {
275            if f.starts_with("error.") && !is_error_context {
276                let _ = writeln!(out, "    // skipped: field '{f}' not available on result type");
277                return;
278            }
279            if !f.starts_with("error.") && !field_resolver.is_valid_for_result(f) {
280                let _ = writeln!(out, "    // skipped: field '{f}' not available on result type");
281                return;
282            }
283        }
284    }
285
286    // Check if this field was unwrapped (i.e., it is optional and was bound to a local).
287    let is_unwrapped = assertion
288        .field
289        .as_ref()
290        .is_some_and(|f| unwrapped_fields.iter().any(|(ff, _)| ff == f));
291
292    // When in error context with returns_result=true and accessing a field (not an error check),
293    // we need to unwrap the Result first. The test generator creates a binding like
294    // `let result_ok = result.as_ref().ok();` which we can dereference here.
295    // Exception: fields prefixed with "error." access the Err value, not the Ok value.
296    let has_field = assertion.field.as_ref().is_some_and(|f| !f.is_empty());
297    let is_field_assertion = !matches!(assertion.assertion_type.as_str(), "error" | "not_error");
298    let is_error_field = assertion.field.as_ref().is_some_and(|f| f.starts_with("error."));
299    let effective_result_var =
300        if has_field && is_error_context && returns_result && is_field_assertion && !is_error_field {
301            // Dereference the Option<&T> bound as {result_var}_ok
302            format!("{result_var}_ok.as_ref().unwrap()")
303        } else {
304            result_var.to_string()
305        };
306
307    // Determine field access expression:
308    // 1. If the field was unwrapped to a local var, use that local var name.
309    // 2. When result_is_simple, the function returns a plain type (String etc.) — use result_var.
310    // 3. When the field path is exactly the result var name (sentinel: `field: "result"`),
311    //    refer to the result variable directly to avoid emitting `result.result`.
312    // 4. When the result is a Tree, map pseudo-field names to correct Rust expressions.
313    // 5. When the field starts with "error.", resolve against the error type.
314    // 6. Otherwise, use the field resolver to generate the accessor.
315    let field_access = match &assertion.field {
316        Some(f) if !f.is_empty() => {
317            if let Some((_, local_var)) = unwrapped_fields.iter().find(|(ff, _)| ff == f) {
318                local_var.clone()
319            } else if result_is_simple && !f.starts_with("error.") {
320                // Plain return type (String, Vec<T>, etc.) has no struct fields.
321                // Use the result variable directly so assertions operate on the value itself.
322                // Exception: error.* fields must resolve against the Err value, not the
323                // plain result variable, even when the success type is simple (e.g. Bytes).
324                effective_result_var.clone()
325            } else if f == result_var {
326                // Sentinel: fixture uses `field: "result"` (or matches the result variable name)
327                // to refer to the whole return value, not a struct field named "result".
328                effective_result_var.clone()
329            } else if result_is_tree {
330                // Tree is an opaque type — its "fields" are accessed via root_node() or
331                // free functions. Map known pseudo-field names to correct Rust expressions.
332                tree_field_access_expr(f, &effective_result_var, module)
333            } else if let Some(sub) = f.strip_prefix("error.") {
334                // Error-path field: access a field on the Err value rather than the Ok value.
335                // Inline-bind the error so the expression is self-contained.
336                let err_accessor = field_resolver.accessor_for_error(sub, "rust", "__err");
337                format!("{{ let __err = {result_var}.as_ref().err().unwrap(); {err_accessor} }}")
338            } else {
339                field_resolver.accessor(f, "rust", &effective_result_var)
340            }
341        }
342        _ => effective_result_var,
343    };
344
345    match assertion.assertion_type.as_str() {
346        "error" => {
347            let _ = writeln!(out, "    assert!({result_var}.is_err(), \"expected call to fail\");");
348            if let Some(serde_json::Value::String(msg)) = &assertion.value {
349                let escaped = escape_rust(msg);
350                // Match against the Debug format (variant-name-style) and the Display format
351                // (human-readable text). Fixtures often name the error variant ("BadRequest"),
352                // but Display impls typically lowercase with a colon ("bad request: ..."), so
353                // checking both lets either kind of fixture value match.
354                let _ = writeln!(
355                    out,
356                    "    {{ let __e = {result_var}.as_ref().err().unwrap(); assert!(format!(\"{{:?}}\", __e).contains(\"{escaped}\") || __e.to_string().contains(\"{escaped}\"), \"error message mismatch\"); }}"
357                );
358            }
359        }
360        "not_error" => {
361            // Handled at call site; nothing extra needed here.
362        }
363        "equals" => {
364            render_equals_assertion(out, assertion, &field_access, is_unwrapped, field_resolver);
365        }
366        "contains" => {
367            if let Some(val) = &assertion.value {
368                let expected = value_to_rust_string(val);
369                let line = format!(
370                    "    assert!(format!(\"{{:?}}\", {field_access}).contains({expected}), \"expected to contain: {{}}\", {expected});"
371                );
372                let _ = writeln!(out, "{line}");
373            }
374        }
375        "contains_all" => {
376            if let Some(values) = &assertion.values {
377                for val in values {
378                    let expected = value_to_rust_string(val);
379                    let line = format!(
380                        "    assert!(format!(\"{{:?}}\", {field_access}).contains({expected}), \"expected to contain: {{}}\", {expected});"
381                    );
382                    let _ = writeln!(out, "{line}");
383                }
384            }
385        }
386        "not_contains" => {
387            if let Some(val) = &assertion.value {
388                let expected = value_to_rust_string(val);
389                let line = format!(
390                    "    assert!(!format!(\"{{:?}}\", {field_access}).contains({expected}), \"expected NOT to contain: {{}}\", {expected});"
391                );
392                let _ = writeln!(out, "{line}");
393            }
394        }
395        "not_empty" => {
396            render_not_empty_assertion(
397                out,
398                assertion,
399                &field_access,
400                result_var,
401                result_is_option,
402                is_unwrapped,
403                field_resolver,
404            );
405        }
406        "is_empty" => {
407            render_is_empty_assertion(out, assertion, &field_access, is_unwrapped, field_resolver);
408        }
409        "contains_any" => {
410            if let Some(values) = &assertion.values {
411                let checks: Vec<String> = values
412                    .iter()
413                    .map(|v| {
414                        let expected = value_to_rust_string(v);
415                        format!("{field_access}.contains({expected})")
416                    })
417                    .collect();
418                let joined = checks.join(" || ");
419                let _ = writeln!(
420                    out,
421                    "    assert!({joined}, \"expected to contain at least one of the specified values\");"
422                );
423            }
424        }
425        "greater_than" => {
426            if let Some(val) = &assertion.value {
427                // Skip comparisons with negative values against unsigned types (.len() etc.)
428                if val.as_f64().is_some_and(|n| n < 0.0) {
429                    let _ = writeln!(
430                        out,
431                        "    // skipped: greater_than with negative value is always true for unsigned types"
432                    );
433                } else if val.as_u64() == Some(0) {
434                    if field_access.ends_with(".len()") {
435                        // Clippy prefers !is_empty() over len() > 0 for collections.
436                        let base = field_access.strip_suffix(".len()").unwrap();
437                        let _ = writeln!(out, "    assert!(!{base}.is_empty(), \"expected > 0\");");
438                    } else {
439                        // Scalar types (usize, u64, etc.) — use direct comparison.
440                        let _ = writeln!(out, "    assert!({field_access} > 0, \"expected > 0\");");
441                    }
442                } else {
443                    let lit = numeric_literal(val);
444                    let _ = writeln!(out, "    assert!({field_access} > {lit}, \"expected > {lit}\");");
445                }
446            }
447        }
448        "less_than" => {
449            if let Some(val) = &assertion.value {
450                let lit = numeric_literal(val);
451                let _ = writeln!(out, "    assert!({field_access} < {lit}, \"expected < {lit}\");");
452            }
453        }
454        "greater_than_or_equal" => {
455            render_gte_assertion(out, assertion, &field_access, is_unwrapped, field_resolver);
456        }
457        "less_than_or_equal" => {
458            if let Some(val) = &assertion.value {
459                let lit = numeric_literal(val);
460                let _ = writeln!(out, "    assert!({field_access} <= {lit}, \"expected <= {lit}\");");
461            }
462        }
463        "starts_with" => {
464            if let Some(val) = &assertion.value {
465                let expected = value_to_rust_string(val);
466                let _ = writeln!(
467                    out,
468                    "    assert!({field_access}.starts_with({expected}), \"expected to start with: {{}}\", {expected});"
469                );
470            }
471        }
472        "ends_with" => {
473            if let Some(val) = &assertion.value {
474                let expected = value_to_rust_string(val);
475                let _ = writeln!(
476                    out,
477                    "    assert!({field_access}.ends_with({expected}), \"expected to end with: {{}}\", {expected});"
478                );
479            }
480        }
481        "min_length" => {
482            if let Some(val) = &assertion.value {
483                if let Some(n) = val.as_u64() {
484                    let _ = writeln!(
485                        out,
486                        "    assert!({field_access}.len() >= {n}, \"expected length >= {n}, got {{}}\", {field_access}.len());"
487                    );
488                }
489            }
490        }
491        "max_length" => {
492            if let Some(val) = &assertion.value {
493                if let Some(n) = val.as_u64() {
494                    let _ = writeln!(
495                        out,
496                        "    assert!({field_access}.len() <= {n}, \"expected length <= {n}, got {{}}\", {field_access}.len());"
497                    );
498                }
499            }
500        }
501        "count_min" => {
502            render_count_min_assertion(out, assertion, &field_access, is_unwrapped, field_resolver);
503        }
504        "count_equals" => {
505            render_count_equals_assertion(out, assertion, &field_access, is_unwrapped, field_resolver);
506        }
507        "is_true" => {
508            let _ = writeln!(out, "    assert!({field_access}, \"expected true\");");
509        }
510        "is_false" => {
511            let _ = writeln!(out, "    assert!(!{field_access}, \"expected false\");");
512        }
513        "method_result" => {
514            render_method_result_assertion(out, assertion, &field_access, result_is_tree, module);
515        }
516        other => {
517            panic!("Rust e2e generator: unsupported assertion type: {other}");
518        }
519    }
520}
521
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use super::*;
    use crate::field_access::FieldResolver;
    use crate::fixture::Assertion;

    /// Builds a `FieldResolver` from empty maps and sets.
    fn empty_resolver() -> FieldResolver {
        FieldResolver::new(
            &HashMap::new(),
            &HashSet::new(),
            &HashSet::new(),
            &HashSet::new(),
            &HashSet::new(),
        )
    }

    /// Builds an `Assertion` with the given type, optional field path and optional value;
    /// every other member keeps its default.
    fn make_assertion(assertion_type: &str, field: Option<&str>, value: Option<serde_json::Value>) -> Assertion {
        Assertion {
            assertion_type: assertion_type.to_owned(),
            field: field.map(String::from),
            value,
            ..Assertion::default()
        }
    }

    /// Renders `assertion` against a variable named `result` using an empty resolver and
    /// returns the generated source. Only the error-context and Vec-result flags vary
    /// between the tests below; all other flags are `false`.
    fn render(assertion: &Assertion, is_error_context: bool, result_is_vec: bool) -> String {
        let resolver = empty_resolver();
        let mut generated = String::new();
        render_assertion(
            &mut generated,
            assertion,
            "result",
            "my_mod",
            "dep",
            is_error_context,
            &[],
            &resolver,
            false, // result_is_tree
            false, // result_is_simple
            result_is_vec,
            false, // result_is_option
            false, // returns_result
        );
        generated
    }

    #[test]
    fn render_assertion_error_type_emits_is_err_check() {
        let out = render(&make_assertion("error", None, None), true, false);
        assert!(out.contains("is_err()"), "got: {out}");
    }

    #[test]
    fn render_assertion_vec_result_wraps_in_for_loop() {
        let out = render(&make_assertion("not_empty", Some("content"), None), false, true);
        assert!(out.contains("for r in"), "got: {out}");
    }

    #[test]
    fn render_assertion_not_empty_bare_result_uses_is_empty() {
        let out = render(&make_assertion("not_empty", None, None), false, false);
        assert!(out.contains("is_empty()"), "got: {out}");
    }
}
617}