1use prosaic_core::{
5 Context, Engine, RenderExplanation, RstRelation, Session, Template, Value, score_faithfulness,
6};
7
8use crate::error::ProjectError;
9use crate::scenario::{Expected, ExpectedDiscourse, Scenario, ScenarioEvent};
10
11#[derive(Debug, Clone)]
12pub struct ScenarioOutcome {
13 pub scenario_name: String,
14 pub verdict: ScenarioVerdict,
15 pub actual_output: String,
16 pub event_outputs: Vec<String>,
17 pub failures: Vec<String>,
18}
19
/// Overall pass/fail status of a scenario run.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ScenarioVerdict {
    /// No check recorded a failure.
    Pass,
    /// At least one check recorded a failure.
    Fail,
}
25
26pub struct ScenarioRunner<'a> {
27 engine: &'a Engine,
28}
29
30impl<'a> ScenarioRunner<'a> {
31 pub fn new(engine: &'a Engine) -> Self {
32 Self { engine }
33 }
34
35 pub fn run(&self, scenario: &Scenario) -> Result<ScenarioOutcome, ProjectError> {
36 reject_unsupported_engine_overrides(scenario)?;
37
38 let mut event_data = Vec::with_capacity(scenario.events.len());
39 for event in &scenario.events {
40 let ctx = scenario_event_to_context(event);
41 let relation = parse_rst_relation(&scenario.name, event)?;
42 event_data.push((event.template.clone(), ctx, relation));
43 }
44
45 let mut explain_session = Session::new();
46 let mut event_outputs = Vec::with_capacity(event_data.len());
47 let mut explanations = Vec::with_capacity(event_data.len());
48 let mut faithfulness_scores = Vec::with_capacity(event_data.len());
49
50 for (template, ctx, _) in &event_data {
51 let explanation = self
52 .engine
53 .render_explained(&mut explain_session, template, ctx)
54 .map_err(|e| ProjectError::ScenarioValidation {
55 name: scenario.name.clone(),
56 reason: format!("event template `{template}`: {e}"),
57 })?;
58 let parsed_template = Template::parse(&explanation.variant_source).map_err(|e| {
59 ProjectError::ScenarioValidation {
60 name: scenario.name.clone(),
61 reason: format!("selected template `{template}` could not be parsed: {e}"),
62 }
63 })?;
64 let literals = parsed_template.literal_tokens();
65 let score =
66 score_faithfulness(&explanation.output, ctx, &literals, self.engine.language());
67 event_outputs.push(explanation.output.clone());
68 explanations.push(explanation);
69 faithfulness_scores.push(score);
70 }
71
72 let actual_output = if event_data.iter().any(|(_, _, relation)| relation.is_some()) {
73 let mut session = Session::new();
74 let batch: Vec<(&str, Context, Option<RstRelation>)> = event_data
75 .iter()
76 .map(|(template, ctx, relation)| (template.as_str(), ctx.clone(), *relation))
77 .collect();
78 self.engine
79 .render_batch_with_relations(&mut session, &batch)
80 .map_err(|e| ProjectError::ScenarioValidation {
81 name: scenario.name.clone(),
82 reason: format!("rendering RST scenario: {e}"),
83 })?
84 } else {
85 event_outputs.join(" ")
86 };
87
88 let mut failures = Vec::new();
89 if let Some(expected) = &scenario.expected {
90 check_expected(
91 expected,
92 &actual_output,
93 &explanations,
94 &faithfulness_scores,
95 &mut failures,
96 );
97 }
98 if let Some(min) = scenario.engine.faithfulness_min {
99 check_faithfulness_min(min as f32, &faithfulness_scores, &mut failures);
100 }
101
102 let verdict = if failures.is_empty() {
103 ScenarioVerdict::Pass
104 } else {
105 ScenarioVerdict::Fail
106 };
107
108 Ok(ScenarioOutcome {
109 scenario_name: scenario.name.clone(),
110 verdict,
111 actual_output,
112 event_outputs,
113 failures,
114 })
115 }
116}
117
118fn reject_unsupported_engine_overrides(scenario: &Scenario) -> Result<(), ProjectError> {
119 if scenario.engine.variation.is_some()
120 || scenario.engine.language.is_some()
121 || scenario.engine.salience_thresholds.is_some()
122 {
123 return Err(ProjectError::ScenarioValidation {
124 name: scenario.name.clone(),
125 reason: "scenario engine overrides for variation, language, and salience_thresholds are not supported by ScenarioRunner; configure the Project engine instead".to_string(),
126 });
127 }
128 Ok(())
129}
130
131fn scenario_event_to_context(event: &ScenarioEvent) -> Context {
132 let mut ctx = Context::new();
133 for (k, v) in &event.context {
134 ctx.insert(k.clone(), toml_to_value(v));
135 }
136 ctx
137}
138
139fn toml_to_value(v: &toml::Value) -> Value {
140 use toml::Value as TV;
141 match v {
142 TV::String(s) => Value::String(s.clone()),
143 TV::Integer(i) => Value::Number(*i),
144 TV::Float(f) => Value::Number(*f as i64),
145 TV::Boolean(b) => Value::Number(if *b { 1 } else { 0 }),
146 TV::Array(items) => Value::List(
147 items
148 .iter()
149 .map(|i| match i {
150 TV::String(s) => s.clone(),
151 other => other.to_string(),
152 })
153 .collect(),
154 ),
155 _ => Value::String(v.to_string()),
156 }
157}
158
159fn parse_rst_relation(
160 scenario_name: &str,
161 event: &ScenarioEvent,
162) -> Result<Option<RstRelation>, ProjectError> {
163 let Some(raw) = &event.rst_relation else {
164 return Ok(None);
165 };
166 let normalized = raw.trim().to_ascii_lowercase().replace(['-', ' '], "_");
167 let relation = match normalized.as_str() {
168 "elaboration" => RstRelation::Elaboration,
169 "contrast" => RstRelation::Contrast,
170 "cause" => RstRelation::Cause,
171 "result" => RstRelation::Result,
172 "concession" => RstRelation::Concession,
173 "sequence" => RstRelation::Sequence,
174 "condition" => RstRelation::Condition,
175 "background" => RstRelation::Background,
176 "summary" => RstRelation::Summary,
177 other => {
178 return Err(ProjectError::ScenarioValidation {
179 name: scenario_name.to_string(),
180 reason: format!("unknown rst_relation `{other}`"),
181 });
182 }
183 };
184 Ok(Some(relation))
185}
186
187fn check_expected(
188 expected: &Expected,
189 actual: &str,
190 explanations: &[RenderExplanation],
191 faithfulness_scores: &[prosaic_core::FaithfulnessScore],
192 failures: &mut Vec<String>,
193) {
194 if let Some(ref out) = expected.output {
195 let actual_norm = actual.split_whitespace().collect::<Vec<_>>().join(" ");
196 let expected_norm = out.split_whitespace().collect::<Vec<_>>().join(" ");
197 if actual_norm != expected_norm {
198 failures.push(format!(
199 "output mismatch:\n expected: {expected_norm}\n actual: {actual_norm}"
200 ));
201 }
202 }
203 if let Some(min) = expected.faithfulness_min {
204 check_faithfulness_min(min as f32, faithfulness_scores, failures);
205 }
206 for discourse in &expected.discourse {
207 check_expected_discourse(discourse, explanations, actual, failures);
208 }
209}
210
211fn check_faithfulness_min(
212 min: f32,
213 scores: &[prosaic_core::FaithfulnessScore],
214 failures: &mut Vec<String>,
215) {
216 for (idx, score) in scores.iter().enumerate() {
217 if !score.passes(min) {
218 failures.push(format!(
219 "faithfulness below threshold at event {idx}: precision={:.3}, polarity_match={}, required={min:.3}, unentailed={:?}",
220 score.precision, score.polarity_match, score.unentailed
221 ));
222 }
223 }
224}
225
226fn check_expected_discourse(
227 expected: &ExpectedDiscourse,
228 explanations: &[RenderExplanation],
229 actual_output: &str,
230 failures: &mut Vec<String>,
231) {
232 let Some(explanation) = explanations.get(expected.event_index) else {
233 failures.push(format!(
234 "discourse expectation references missing event index {}",
235 expected.event_index
236 ));
237 return;
238 };
239
240 if let Some(reference_form) = &expected.reference_form {
241 let actual = explanation
242 .reference_form
243 .map(|form| format!("{form:?}"))
244 .unwrap_or_else(|| "None".to_string());
245 if !actual.eq_ignore_ascii_case(reference_form) {
246 failures.push(format!(
247 "reference_form mismatch at event {}: expected {reference_form}, actual {actual}",
248 expected.event_index
249 ));
250 }
251 }
252
253 if let Some(needle) = &expected.connective_contains {
254 let connective = explanation.connective.unwrap_or("");
255 if !connective.contains(needle)
256 && !explanation.output.contains(needle)
257 && !actual_output.contains(needle)
258 {
259 failures.push(format!(
260 "connective mismatch at event {}: expected text containing `{needle}`, actual connective={:?}, output={}",
261 expected.event_index, explanation.connective, explanation.output
262 ));
263 }
264 }
265
266 if let Some(transition) = &expected.transition {
267 let actual = format!("{:?}", explanation.centering_transition);
268 if !actual.eq_ignore_ascii_case(transition) {
269 failures.push(format!(
270 "transition mismatch at event {}: expected {transition}, actual {actual}",
271 expected.event_index
272 ));
273 }
274 }
275}