Skip to main content

prosaic_project/
runner.rs

//! Scenario runner — render a scenario's events through the project
//! [`Engine`] and check the output, faithfulness, and discourse assertions
//! against the scenario's expectations.
3
4use prosaic_core::{
5    Context, Engine, RenderExplanation, RstRelation, Session, Template, Value, score_faithfulness,
6};
7
8use crate::error::ProjectError;
9use crate::scenario::{Expected, ExpectedDiscourse, Scenario, ScenarioEvent};
10
11#[derive(Debug, Clone)]
12pub struct ScenarioOutcome {
13    pub scenario_name: String,
14    pub verdict: ScenarioVerdict,
15    pub actual_output: String,
16    pub event_outputs: Vec<String>,
17    pub failures: Vec<String>,
18}
19
/// Overall result of a scenario run.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ScenarioVerdict {
    /// Every check succeeded; no failures were recorded.
    Pass,
    /// At least one check failed; see [`ScenarioOutcome::failures`].
    Fail,
}
25
26pub struct ScenarioRunner<'a> {
27    engine: &'a Engine,
28}
29
30impl<'a> ScenarioRunner<'a> {
31    pub fn new(engine: &'a Engine) -> Self {
32        Self { engine }
33    }
34
35    pub fn run(&self, scenario: &Scenario) -> Result<ScenarioOutcome, ProjectError> {
36        reject_unsupported_engine_overrides(scenario)?;
37
38        let mut event_data = Vec::with_capacity(scenario.events.len());
39        for event in &scenario.events {
40            let ctx = scenario_event_to_context(event);
41            let relation = parse_rst_relation(&scenario.name, event)?;
42            event_data.push((event.template.clone(), ctx, relation));
43        }
44
45        let mut explain_session = Session::new();
46        let mut event_outputs = Vec::with_capacity(event_data.len());
47        let mut explanations = Vec::with_capacity(event_data.len());
48        let mut faithfulness_scores = Vec::with_capacity(event_data.len());
49
50        for (template, ctx, _) in &event_data {
51            let explanation = self
52                .engine
53                .render_explained(&mut explain_session, template, ctx)
54                .map_err(|e| ProjectError::ScenarioValidation {
55                    name: scenario.name.clone(),
56                    reason: format!("event template `{template}`: {e}"),
57                })?;
58            let parsed_template = Template::parse(&explanation.variant_source).map_err(|e| {
59                ProjectError::ScenarioValidation {
60                    name: scenario.name.clone(),
61                    reason: format!("selected template `{template}` could not be parsed: {e}"),
62                }
63            })?;
64            let literals = parsed_template.literal_tokens();
65            let score =
66                score_faithfulness(&explanation.output, ctx, &literals, self.engine.language());
67            event_outputs.push(explanation.output.clone());
68            explanations.push(explanation);
69            faithfulness_scores.push(score);
70        }
71
72        let actual_output = if event_data.iter().any(|(_, _, relation)| relation.is_some()) {
73            let mut session = Session::new();
74            let batch: Vec<(&str, Context, Option<RstRelation>)> = event_data
75                .iter()
76                .map(|(template, ctx, relation)| (template.as_str(), ctx.clone(), *relation))
77                .collect();
78            self.engine
79                .render_batch_with_relations(&mut session, &batch)
80                .map_err(|e| ProjectError::ScenarioValidation {
81                    name: scenario.name.clone(),
82                    reason: format!("rendering RST scenario: {e}"),
83                })?
84        } else {
85            event_outputs.join(" ")
86        };
87
88        let mut failures = Vec::new();
89        if let Some(expected) = &scenario.expected {
90            check_expected(
91                expected,
92                &actual_output,
93                &explanations,
94                &faithfulness_scores,
95                &mut failures,
96            );
97        }
98        if let Some(min) = scenario.engine.faithfulness_min {
99            check_faithfulness_min(min as f32, &faithfulness_scores, &mut failures);
100        }
101
102        let verdict = if failures.is_empty() {
103            ScenarioVerdict::Pass
104        } else {
105            ScenarioVerdict::Fail
106        };
107
108        Ok(ScenarioOutcome {
109            scenario_name: scenario.name.clone(),
110            verdict,
111            actual_output,
112            event_outputs,
113            failures,
114        })
115    }
116}
117
118fn reject_unsupported_engine_overrides(scenario: &Scenario) -> Result<(), ProjectError> {
119    if scenario.engine.variation.is_some()
120        || scenario.engine.language.is_some()
121        || scenario.engine.salience_thresholds.is_some()
122    {
123        return Err(ProjectError::ScenarioValidation {
124            name: scenario.name.clone(),
125            reason: "scenario engine overrides for variation, language, and salience_thresholds are not supported by ScenarioRunner; configure the Project engine instead".to_string(),
126        });
127    }
128    Ok(())
129}
130
131fn scenario_event_to_context(event: &ScenarioEvent) -> Context {
132    let mut ctx = Context::new();
133    for (k, v) in &event.context {
134        ctx.insert(k.clone(), toml_to_value(v));
135    }
136    ctx
137}
138
139fn toml_to_value(v: &toml::Value) -> Value {
140    use toml::Value as TV;
141    match v {
142        TV::String(s) => Value::String(s.clone()),
143        TV::Integer(i) => Value::Number(*i),
144        TV::Float(f) => Value::Number(*f as i64),
145        TV::Boolean(b) => Value::Number(if *b { 1 } else { 0 }),
146        TV::Array(items) => Value::List(
147            items
148                .iter()
149                .map(|i| match i {
150                    TV::String(s) => s.clone(),
151                    other => other.to_string(),
152                })
153                .collect(),
154        ),
155        _ => Value::String(v.to_string()),
156    }
157}
158
159fn parse_rst_relation(
160    scenario_name: &str,
161    event: &ScenarioEvent,
162) -> Result<Option<RstRelation>, ProjectError> {
163    let Some(raw) = &event.rst_relation else {
164        return Ok(None);
165    };
166    let normalized = raw.trim().to_ascii_lowercase().replace(['-', ' '], "_");
167    let relation = match normalized.as_str() {
168        "elaboration" => RstRelation::Elaboration,
169        "contrast" => RstRelation::Contrast,
170        "cause" => RstRelation::Cause,
171        "result" => RstRelation::Result,
172        "concession" => RstRelation::Concession,
173        "sequence" => RstRelation::Sequence,
174        "condition" => RstRelation::Condition,
175        "background" => RstRelation::Background,
176        "summary" => RstRelation::Summary,
177        other => {
178            return Err(ProjectError::ScenarioValidation {
179                name: scenario_name.to_string(),
180                reason: format!("unknown rst_relation `{other}`"),
181            });
182        }
183    };
184    Ok(Some(relation))
185}
186
187fn check_expected(
188    expected: &Expected,
189    actual: &str,
190    explanations: &[RenderExplanation],
191    faithfulness_scores: &[prosaic_core::FaithfulnessScore],
192    failures: &mut Vec<String>,
193) {
194    if let Some(ref out) = expected.output {
195        let actual_norm = actual.split_whitespace().collect::<Vec<_>>().join(" ");
196        let expected_norm = out.split_whitespace().collect::<Vec<_>>().join(" ");
197        if actual_norm != expected_norm {
198            failures.push(format!(
199                "output mismatch:\n  expected: {expected_norm}\n  actual:   {actual_norm}"
200            ));
201        }
202    }
203    if let Some(min) = expected.faithfulness_min {
204        check_faithfulness_min(min as f32, faithfulness_scores, failures);
205    }
206    for discourse in &expected.discourse {
207        check_expected_discourse(discourse, explanations, actual, failures);
208    }
209}
210
211fn check_faithfulness_min(
212    min: f32,
213    scores: &[prosaic_core::FaithfulnessScore],
214    failures: &mut Vec<String>,
215) {
216    for (idx, score) in scores.iter().enumerate() {
217        if !score.passes(min) {
218            failures.push(format!(
219                "faithfulness below threshold at event {idx}: precision={:.3}, polarity_match={}, required={min:.3}, unentailed={:?}",
220                score.precision, score.polarity_match, score.unentailed
221            ));
222        }
223    }
224}
225
226fn check_expected_discourse(
227    expected: &ExpectedDiscourse,
228    explanations: &[RenderExplanation],
229    actual_output: &str,
230    failures: &mut Vec<String>,
231) {
232    let Some(explanation) = explanations.get(expected.event_index) else {
233        failures.push(format!(
234            "discourse expectation references missing event index {}",
235            expected.event_index
236        ));
237        return;
238    };
239
240    if let Some(reference_form) = &expected.reference_form {
241        let actual = explanation
242            .reference_form
243            .map(|form| format!("{form:?}"))
244            .unwrap_or_else(|| "None".to_string());
245        if !actual.eq_ignore_ascii_case(reference_form) {
246            failures.push(format!(
247                "reference_form mismatch at event {}: expected {reference_form}, actual {actual}",
248                expected.event_index
249            ));
250        }
251    }
252
253    if let Some(needle) = &expected.connective_contains {
254        let connective = explanation.connective.unwrap_or("");
255        if !connective.contains(needle)
256            && !explanation.output.contains(needle)
257            && !actual_output.contains(needle)
258        {
259            failures.push(format!(
260                "connective mismatch at event {}: expected text containing `{needle}`, actual connective={:?}, output={}",
261                expected.event_index, explanation.connective, explanation.output
262            ));
263        }
264    }
265
266    if let Some(transition) = &expected.transition {
267        let actual = format!("{:?}", explanation.centering_transition);
268        if !actual.eq_ignore_ascii_case(transition) {
269            failures.push(format!(
270                "transition mismatch at event {}: expected {transition}, actual {actual}",
271                expected.event_index
272            ));
273        }
274    }
275}