Skip to main content

alef_e2e/codegen/rust/
assertions.rs

1//! Assertion rendering for Rust e2e tests.
2
3use std::fmt::Write as FmtWrite;
4
5use crate::escape::escape_rust;
6use crate::field_access::FieldResolver;
7use crate::fixture::Assertion;
8
9use super::assertion_helpers::{
10    render_count_equals_assertion, render_count_min_assertion, render_equals_assertion, render_gte_assertion,
11    render_is_empty_assertion, render_method_result_assertion, render_not_empty_assertion,
12};
13use super::assertion_synthetic::{
14    numeric_literal, render_chunks_have_content, render_chunks_have_embeddings, render_embedding_dimensions,
15    render_embedding_quality, render_embeddings_assertion, render_keywords_assertion, render_keywords_count_assertion,
16    tree_field_access_expr, value_to_rust_string,
17};
18
19/// Render a single assertion into the test function body.
20#[allow(clippy::too_many_arguments)]
21pub fn render_assertion(
22    out: &mut String,
23    assertion: &Assertion,
24    result_var: &str,
25    module: &str,
26    dep_name: &str,
27    is_error_context: bool,
28    unwrapped_fields: &[(String, String)], // (fixture_field, local_var)
29    field_resolver: &FieldResolver,
30    result_is_tree: bool,
31    result_is_simple: bool,
32    result_is_vec: bool,
33    result_is_option: bool,
34    returns_result: bool,
35) {
36    // Vec<T> result: iterate per-element so each assertion checks every element.
37    // Field-path assertions become `for r in &{result} { <assert using r> }`.
38    // Length-style assertions on the Vec itself (no field path) operate on the
39    // Vec directly.
40    let has_field = assertion.field.as_ref().is_some_and(|f| !f.is_empty());
41    if result_is_vec && has_field && !is_error_context {
42        let _ = writeln!(out, "    for r in &{result_var} {{");
43        render_assertion(
44            out,
45            assertion,
46            "r",
47            module,
48            dep_name,
49            is_error_context,
50            unwrapped_fields,
51            field_resolver,
52            result_is_tree,
53            result_is_simple,
54            false, // already inside loop
55            result_is_option,
56            returns_result,
57        );
58        let _ = writeln!(out, "    }}");
59        return;
60    }
61    // Option<T> result: map `is_empty`/`not_empty` to `is_none()`/`is_some()`,
62    // and unwrap the inner value before any other assertion runs.
63    if result_is_option && !is_error_context {
64        let assertion_type = assertion.assertion_type.as_str();
65        if !has_field && (assertion_type == "is_empty" || assertion_type == "not_empty") {
66            let check = if assertion_type == "is_empty" {
67                "is_none"
68            } else {
69                "is_some"
70            };
71            let _ = writeln!(
72                out,
73                "    assert!({result_var}.{check}(), \"expected Option to be {check}\");"
74            );
75            return;
76        }
77        // For any other assertion shape, unwrap the Option and recurse with a
78        // bare reference variable so the rest of the renderer treats the inner
79        // value as the result.
80        let _ = writeln!(
81            out,
82            "    let r = {result_var}.as_ref().expect(\"Option<T> should be Some\");"
83        );
84        render_assertion(
85            out,
86            assertion,
87            "r",
88            module,
89            dep_name,
90            is_error_context,
91            unwrapped_fields,
92            field_resolver,
93            result_is_tree,
94            result_is_simple,
95            result_is_vec,
96            false, // already unwrapped
97            returns_result,
98        );
99        return;
100    }
101    let _ = dep_name;
102    // Handle synthetic fields like chunks_have_content (derived assertions).
103    // These are computed expressions, not real struct fields — intercept before
104    // the is_valid_for_result check so they are never treated as field accesses.
105    if let Some(f) = &assertion.field {
106        match f.as_str() {
107            "chunks_have_content" => {
108                render_chunks_have_content(out, result_var, assertion.assertion_type.as_str());
109                return;
110            }
111            "chunks_have_embeddings" => {
112                render_chunks_have_embeddings(out, result_var, assertion.assertion_type.as_str());
113                return;
114            }
115            "embeddings" => {
116                render_embeddings_assertion(out, result_var, assertion);
117                return;
118            }
119            "embedding_dimensions" => {
120                render_embedding_dimensions(out, result_var, assertion);
121                return;
122            }
123            "embeddings_valid" | "embeddings_finite" | "embeddings_non_zero" | "embeddings_normalized" => {
124                render_embedding_quality(out, result_var, f, assertion.assertion_type.as_str());
125                return;
126            }
127            "keywords" => {
128                render_keywords_assertion(out, result_var, assertion);
129                return;
130            }
131            "keywords_count" => {
132                render_keywords_count_assertion(out, result_var, assertion);
133                return;
134            }
135            _ => {}
136        }
137    }
138
139    // Streaming virtual fields: intercept before is_valid_for_result so they are
140    // never skipped.  These fields resolve against the `chunks` collected-list variable.
141    if let Some(f) = &assertion.field {
142        if !f.is_empty() && crate::codegen::streaming_assertions::is_streaming_virtual_field(f) {
143            if let Some(expr) =
144                crate::codegen::streaming_assertions::StreamingFieldResolver::accessor(f, "rust", "chunks")
145            {
146                match assertion.assertion_type.as_str() {
147                    "count_min" => {
148                        if let Some(val) = &assertion.value {
149                            if let Some(n) = val.as_u64() {
150                                let _ = writeln!(
151                                    out,
152                                    "    assert!({expr}.len() >= {n} as usize, \"expected >= {n} chunks\");"
153                                );
154                            }
155                        }
156                    }
157                    "count_equals" => {
158                        if let Some(val) = &assertion.value {
159                            if let Some(n) = val.as_u64() {
160                                let _ = writeln!(
161                                    out,
162                                    "    assert_eq!({expr}.len(), {n} as usize, \"expected exactly {n} chunks\");"
163                                );
164                            }
165                        }
166                    }
167                    "equals" => {
168                        if let Some(serde_json::Value::String(s)) = &assertion.value {
169                            let escaped = crate::escape::escape_rust(s);
170                            let _ = writeln!(out, "    assert_eq!({expr}, \"{escaped}\");");
171                        } else if let Some(val) = &assertion.value {
172                            let lit = super::assertion_synthetic::numeric_literal(val);
173                            let _ = writeln!(out, "    assert_eq!({expr}, {lit});");
174                        }
175                    }
176                    "not_empty" => {
177                        let _ = writeln!(out, "    assert!(!{expr}.is_empty(), \"expected non-empty\");");
178                    }
179                    "is_empty" => {
180                        let _ = writeln!(out, "    assert!({expr}.is_empty(), \"expected empty\");");
181                    }
182                    "is_true" => {
183                        let _ = writeln!(out, "    assert!({expr}, \"expected true\");");
184                    }
185                    "is_false" => {
186                        let _ = writeln!(out, "    assert!(!{expr}, \"expected false\");");
187                    }
188                    "greater_than" => {
189                        if let Some(val) = &assertion.value {
190                            let lit = super::assertion_synthetic::numeric_literal(val);
191                            let _ = writeln!(out, "    assert!({expr} > {lit}, \"expected > {lit}\");");
192                        }
193                    }
194                    "greater_than_or_equal" => {
195                        if let Some(val) = &assertion.value {
196                            let lit = super::assertion_synthetic::numeric_literal(val);
197                            let _ = writeln!(out, "    assert!({expr} >= {lit}, \"expected >= {lit}\");");
198                        }
199                    }
200                    "contains" => {
201                        if let Some(serde_json::Value::String(s)) = &assertion.value {
202                            let escaped = crate::escape::escape_rust(s);
203                            let _ = writeln!(
204                                out,
205                                "    assert!({expr}.contains(\"{escaped}\"), \"expected to contain: {escaped}\");"
206                            );
207                        }
208                    }
209                    _ => {
210                        let _ = writeln!(
211                            out,
212                            "    // streaming field '{f}': assertion type '{}' not rendered",
213                            assertion.assertion_type
214                        );
215                    }
216                }
217            }
218            return;
219        }
220    }
221
222    // Skip assertions on fields that don't exist on the result type.
223    // Exception: fields prefixed with "error." target the error value in error-context
224    // assertions — they are resolved against the error type via accessor_for_error,
225    // not against the success result type, so they must not be skipped here.
226    if let Some(f) = &assertion.field {
227        if !f.is_empty() && !f.starts_with("error.") && !field_resolver.is_valid_for_result(f) {
228            let _ = writeln!(out, "    // skipped: field '{f}' not available on result type");
229            return;
230        }
231    }
232
233    // Check if this field was unwrapped (i.e., it is optional and was bound to a local).
234    let is_unwrapped = assertion
235        .field
236        .as_ref()
237        .is_some_and(|f| unwrapped_fields.iter().any(|(ff, _)| ff == f));
238
239    // When in error context with returns_result=true and accessing a field (not an error check),
240    // we need to unwrap the Result first. The test generator creates a binding like
241    // `let result_ok = result.as_ref().ok();` which we can dereference here.
242    // Exception: fields prefixed with "error." access the Err value, not the Ok value.
243    let has_field = assertion.field.as_ref().is_some_and(|f| !f.is_empty());
244    let is_field_assertion = !matches!(assertion.assertion_type.as_str(), "error" | "not_error");
245    let is_error_field = assertion.field.as_ref().is_some_and(|f| f.starts_with("error."));
246    let effective_result_var =
247        if has_field && is_error_context && returns_result && is_field_assertion && !is_error_field {
248            // Dereference the Option<&T> bound as {result_var}_ok
249            format!("{result_var}_ok.as_ref().unwrap()")
250        } else {
251            result_var.to_string()
252        };
253
254    // Determine field access expression:
255    // 1. If the field was unwrapped to a local var, use that local var name.
256    // 2. When result_is_simple, the function returns a plain type (String etc.) — use result_var.
257    // 3. When the field path is exactly the result var name (sentinel: `field: "result"`),
258    //    refer to the result variable directly to avoid emitting `result.result`.
259    // 4. When the result is a Tree, map pseudo-field names to correct Rust expressions.
260    // 5. When the field starts with "error.", resolve against the error type.
261    // 6. Otherwise, use the field resolver to generate the accessor.
262    let field_access = match &assertion.field {
263        Some(f) if !f.is_empty() => {
264            if let Some((_, local_var)) = unwrapped_fields.iter().find(|(ff, _)| ff == f) {
265                local_var.clone()
266            } else if result_is_simple && !f.starts_with("error.") {
267                // Plain return type (String, Vec<T>, etc.) has no struct fields.
268                // Use the result variable directly so assertions operate on the value itself.
269                // Exception: error.* fields must resolve against the Err value, not the
270                // plain result variable, even when the success type is simple (e.g. Bytes).
271                effective_result_var.clone()
272            } else if f == result_var {
273                // Sentinel: fixture uses `field: "result"` (or matches the result variable name)
274                // to refer to the whole return value, not a struct field named "result".
275                effective_result_var.clone()
276            } else if result_is_tree {
277                // Tree is an opaque type — its "fields" are accessed via root_node() or
278                // free functions. Map known pseudo-field names to correct Rust expressions.
279                tree_field_access_expr(f, &effective_result_var, module)
280            } else if let Some(sub) = f.strip_prefix("error.") {
281                // Error-path field: access a field on the Err value rather than the Ok value.
282                // Inline-bind the error so the expression is self-contained.
283                let err_accessor = field_resolver.accessor_for_error(sub, "rust", "__err");
284                format!("{{ let __err = {result_var}.as_ref().err().unwrap(); {err_accessor} }}")
285            } else {
286                field_resolver.accessor(f, "rust", &effective_result_var)
287            }
288        }
289        _ => effective_result_var,
290    };
291
292    match assertion.assertion_type.as_str() {
293        "error" => {
294            let _ = writeln!(out, "    assert!({result_var}.is_err(), \"expected call to fail\");");
295            if let Some(serde_json::Value::String(msg)) = &assertion.value {
296                let escaped = escape_rust(msg);
297                // Match against the Debug format (variant-name-style) and the Display format
298                // (human-readable text). Fixtures often name the error variant ("BadRequest"),
299                // but Display impls typically lowercase with a colon ("bad request: ..."), so
300                // checking both lets either kind of fixture value match.
301                let _ = writeln!(
302                    out,
303                    "    {{ let __e = {result_var}.as_ref().err().unwrap(); assert!(format!(\"{{:?}}\", __e).contains(\"{escaped}\") || __e.to_string().contains(\"{escaped}\"), \"error message mismatch\"); }}"
304                );
305            }
306        }
307        "not_error" => {
308            // Handled at call site; nothing extra needed here.
309        }
310        "equals" => {
311            render_equals_assertion(out, assertion, &field_access, is_unwrapped, field_resolver);
312        }
313        "contains" => {
314            if let Some(val) = &assertion.value {
315                let expected = value_to_rust_string(val);
316                let line = format!(
317                    "    assert!(format!(\"{{:?}}\", {field_access}).contains({expected}), \"expected to contain: {{}}\", {expected});"
318                );
319                let _ = writeln!(out, "{line}");
320            }
321        }
322        "contains_all" => {
323            if let Some(values) = &assertion.values {
324                for val in values {
325                    let expected = value_to_rust_string(val);
326                    let line = format!(
327                        "    assert!(format!(\"{{:?}}\", {field_access}).contains({expected}), \"expected to contain: {{}}\", {expected});"
328                    );
329                    let _ = writeln!(out, "{line}");
330                }
331            }
332        }
333        "not_contains" => {
334            if let Some(val) = &assertion.value {
335                let expected = value_to_rust_string(val);
336                let line = format!(
337                    "    assert!(!format!(\"{{:?}}\", {field_access}).contains({expected}), \"expected NOT to contain: {{}}\", {expected});"
338                );
339                let _ = writeln!(out, "{line}");
340            }
341        }
342        "not_empty" => {
343            render_not_empty_assertion(
344                out,
345                assertion,
346                &field_access,
347                result_var,
348                result_is_option,
349                is_unwrapped,
350                field_resolver,
351            );
352        }
353        "is_empty" => {
354            render_is_empty_assertion(out, assertion, &field_access, is_unwrapped, field_resolver);
355        }
356        "contains_any" => {
357            if let Some(values) = &assertion.values {
358                let checks: Vec<String> = values
359                    .iter()
360                    .map(|v| {
361                        let expected = value_to_rust_string(v);
362                        format!("{field_access}.contains({expected})")
363                    })
364                    .collect();
365                let joined = checks.join(" || ");
366                let _ = writeln!(
367                    out,
368                    "    assert!({joined}, \"expected to contain at least one of the specified values\");"
369                );
370            }
371        }
372        "greater_than" => {
373            if let Some(val) = &assertion.value {
374                // Skip comparisons with negative values against unsigned types (.len() etc.)
375                if val.as_f64().is_some_and(|n| n < 0.0) {
376                    let _ = writeln!(
377                        out,
378                        "    // skipped: greater_than with negative value is always true for unsigned types"
379                    );
380                } else if val.as_u64() == Some(0) {
381                    if field_access.ends_with(".len()") {
382                        // Clippy prefers !is_empty() over len() > 0 for collections.
383                        let base = field_access.strip_suffix(".len()").unwrap();
384                        let _ = writeln!(out, "    assert!(!{base}.is_empty(), \"expected > 0\");");
385                    } else {
386                        // Scalar types (usize, u64, etc.) — use direct comparison.
387                        let _ = writeln!(out, "    assert!({field_access} > 0, \"expected > 0\");");
388                    }
389                } else {
390                    let lit = numeric_literal(val);
391                    let _ = writeln!(out, "    assert!({field_access} > {lit}, \"expected > {lit}\");");
392                }
393            }
394        }
395        "less_than" => {
396            if let Some(val) = &assertion.value {
397                let lit = numeric_literal(val);
398                let _ = writeln!(out, "    assert!({field_access} < {lit}, \"expected < {lit}\");");
399            }
400        }
401        "greater_than_or_equal" => {
402            render_gte_assertion(out, assertion, &field_access, is_unwrapped, field_resolver);
403        }
404        "less_than_or_equal" => {
405            if let Some(val) = &assertion.value {
406                let lit = numeric_literal(val);
407                let _ = writeln!(out, "    assert!({field_access} <= {lit}, \"expected <= {lit}\");");
408            }
409        }
410        "starts_with" => {
411            if let Some(val) = &assertion.value {
412                let expected = value_to_rust_string(val);
413                let _ = writeln!(
414                    out,
415                    "    assert!({field_access}.starts_with({expected}), \"expected to start with: {{}}\", {expected});"
416                );
417            }
418        }
419        "ends_with" => {
420            if let Some(val) = &assertion.value {
421                let expected = value_to_rust_string(val);
422                let _ = writeln!(
423                    out,
424                    "    assert!({field_access}.ends_with({expected}), \"expected to end with: {{}}\", {expected});"
425                );
426            }
427        }
428        "min_length" => {
429            if let Some(val) = &assertion.value {
430                if let Some(n) = val.as_u64() {
431                    let _ = writeln!(
432                        out,
433                        "    assert!({field_access}.len() >= {n}, \"expected length >= {n}, got {{}}\", {field_access}.len());"
434                    );
435                }
436            }
437        }
438        "max_length" => {
439            if let Some(val) = &assertion.value {
440                if let Some(n) = val.as_u64() {
441                    let _ = writeln!(
442                        out,
443                        "    assert!({field_access}.len() <= {n}, \"expected length <= {n}, got {{}}\", {field_access}.len());"
444                    );
445                }
446            }
447        }
448        "count_min" => {
449            render_count_min_assertion(out, assertion, &field_access, is_unwrapped, field_resolver);
450        }
451        "count_equals" => {
452            render_count_equals_assertion(out, assertion, &field_access, is_unwrapped, field_resolver);
453        }
454        "is_true" => {
455            let _ = writeln!(out, "    assert!({field_access}, \"expected true\");");
456        }
457        "is_false" => {
458            let _ = writeln!(out, "    assert!(!{field_access}, \"expected false\");");
459        }
460        "method_result" => {
461            render_method_result_assertion(out, assertion, &field_access, result_is_tree, module);
462        }
463        other => {
464            panic!("Rust e2e generator: unsupported assertion type: {other}");
465        }
466    }
467}
468
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use super::*;
    use crate::field_access::FieldResolver;
    use crate::fixture::Assertion;

    /// A resolver constructed from empty maps/sets, so it has no field information.
    fn empty_resolver() -> FieldResolver {
        FieldResolver::new(
            &HashMap::new(),
            &HashSet::new(),
            &HashSet::new(),
            &HashSet::new(),
            &HashSet::new(),
        )
    }

    /// Builds an assertion of `assertion_type`, leaving every other field at its default.
    fn make_assertion(assertion_type: &str, field: Option<&str>, value: Option<serde_json::Value>) -> Assertion {
        Assertion {
            assertion_type: String::from(assertion_type),
            field: field.map(String::from),
            value,
            ..Default::default()
        }
    }

    /// Renders `assertion` against a variable named `result` (module `my_mod`,
    /// dependency `dep`, no unwrapped fields, not a tree / simple / Option value,
    /// no `Result` wrapper) and returns the generated source text.
    fn render(assertion: &Assertion, is_error_context: bool, result_is_vec: bool) -> String {
        let resolver = empty_resolver();
        let mut rendered = String::new();
        render_assertion(
            &mut rendered,
            assertion,
            "result",
            "my_mod",
            "dep",
            is_error_context,
            &[],
            &resolver,
            false, // result_is_tree
            false, // result_is_simple
            result_is_vec,
            false, // result_is_option
            false, // returns_result
        );
        rendered
    }

    #[test]
    fn render_assertion_error_type_emits_is_err_check() {
        let out = render(&make_assertion("error", None, None), true, false);
        assert!(out.contains("is_err()"), "got: {out}");
    }

    #[test]
    fn render_assertion_vec_result_wraps_in_for_loop() {
        let out = render(&make_assertion("not_empty", Some("content"), None), false, true);
        assert!(out.contains("for r in"), "got: {out}");
    }

    #[test]
    fn render_assertion_not_empty_bare_result_uses_is_empty() {
        let out = render(&make_assertion("not_empty", None, None), false, false);
        assert!(out.contains("is_empty()"), "got: {out}");
    }
}