1use std::collections::HashMap;
11use std::path::{Path, PathBuf};
12
13use prosaic_core::{
14 ConnectivePreferences, HedgingCalibration, LengthDistribution, ListStyleBias, PronounDensity,
15 RstRelation, SalienceBias, StyleProfile, StyleProfileError, Verbosity,
16};
17use serde::{Deserialize, Serialize};
18
19use crate::error::ProjectError;
20
21#[derive(Debug, Default, Clone, Serialize, Deserialize)]
28#[serde(default)]
29pub struct StyleProfileConfig {
30 pub extends: Option<String>,
34 pub name: Option<String>,
35 pub verbosity: Option<String>,
36 pub list_style_bias: Option<String>,
37 pub pronoun_density: Option<String>,
38 pub salience: Option<String>,
39 pub sentence_length: Option<LengthDistributionConfig>,
40 pub connectives: Option<ConnectivePreferencesConfig>,
41 pub hedging: Option<HedgingCalibrationConfig>,
42}
43
44#[derive(Debug, Default, Clone, Serialize, Deserialize)]
45#[serde(default)]
46pub struct LengthDistributionConfig {
47 pub short: Option<f32>,
48 pub medium: Option<f32>,
49 pub long: Option<f32>,
50 pub short_max_words: Option<u16>,
51 pub medium_max_words: Option<u16>,
52}
53
54#[derive(Debug, Default, Clone, Serialize, Deserialize)]
55#[serde(default)]
56pub struct ConnectivePreferencesConfig {
57 pub allowed: Option<HashMap<String, Vec<String>>>,
63 pub preferred: Option<HashMap<String, Vec<(String, f32)>>>,
66}
67
68#[derive(Debug, Default, Clone, Serialize, Deserialize)]
69#[serde(default)]
70pub struct HedgingCalibrationConfig {
71 pub offset: Option<i8>,
72 pub forbid: Option<Vec<String>>,
73}
74
75impl StyleProfileConfig {
76 pub fn into_style_profile(self, manifest_dir: &Path) -> Result<StyleProfile, ProjectError> {
84 let merged = self.resolve(manifest_dir, &mut Vec::new())?;
85 merged.build_profile()
86 }
87
88 fn resolve(
89 self,
90 manifest_dir: &Path,
91 seen: &mut Vec<PathBuf>,
92 ) -> Result<StyleProfileConfig, ProjectError> {
93 let base = if let Some(ext_path) = &self.extends {
96 let mut path = manifest_dir.join(ext_path);
97 if !path.is_absolute() {
98 path = manifest_dir.join(ext_path);
99 }
100 let canonical = path.canonicalize().unwrap_or(path.clone());
101 if seen.iter().any(|p| p == &canonical) {
102 return Err(ProjectError::ManifestStyle {
103 reason: format!(
104 "extends cycle detected: `{}` is already in the resolution chain",
105 path.display()
106 ),
107 });
108 }
109 seen.push(canonical);
110 let text = std::fs::read_to_string(&path).map_err(|e| ProjectError::Io {
111 path: path.display().to_string(),
112 cause: e.to_string(),
113 })?;
114 let parent = path
115 .parent()
116 .map(Path::to_path_buf)
117 .unwrap_or_else(|| manifest_dir.to_path_buf());
118 let parsed: StyleProfileConfig =
119 toml::from_str(&text).map_err(|e| ProjectError::TomlParse {
120 file: path.display().to_string(),
121 cause: e.to_string(),
122 })?;
123 Some(parsed.resolve(&parent, seen)?)
124 } else {
125 None
126 };
127
128 Ok(merge_overlay(base.unwrap_or_default(), self))
129 }
130
131 fn build_profile(self) -> Result<StyleProfile, ProjectError> {
132 let mut builder =
133 StyleProfile::builder(self.name.unwrap_or_else(|| String::from("default")));
134 if let Some(v) = self.verbosity {
135 builder = builder.verbosity(parse_verbosity(&v)?);
136 }
137 if let Some(l) = self.list_style_bias {
138 builder = builder.list_style_bias(parse_list_style_bias(&l)?);
139 }
140 if let Some(p) = self.pronoun_density {
141 builder = builder.pronoun_density(parse_pronoun_density(&p)?);
142 }
143 if let Some(s) = self.salience {
144 builder = builder.salience(parse_salience_bias(&s)?);
145 }
146 if let Some(sl) = self.sentence_length {
147 builder = builder.sentence_length(build_length_distribution(sl));
148 }
149 if let Some(c) = self.connectives {
150 builder = builder.connectives(build_connective_preferences(c)?);
151 }
152 if let Some(h) = self.hedging {
153 builder = builder.hedging(build_hedging_calibration(h));
154 }
155 builder.build().map_err(map_style_error)
156 }
157}
158
159fn merge_overlay(base: StyleProfileConfig, overlay: StyleProfileConfig) -> StyleProfileConfig {
160 StyleProfileConfig {
161 extends: None, name: overlay.name.or(base.name),
163 verbosity: overlay.verbosity.or(base.verbosity),
164 list_style_bias: overlay.list_style_bias.or(base.list_style_bias),
165 pronoun_density: overlay.pronoun_density.or(base.pronoun_density),
166 salience: overlay.salience.or(base.salience),
167 sentence_length: merge_length(base.sentence_length, overlay.sentence_length),
168 connectives: merge_connectives(base.connectives, overlay.connectives),
169 hedging: merge_hedging(base.hedging, overlay.hedging),
170 }
171}
172
173fn merge_length(
174 base: Option<LengthDistributionConfig>,
175 overlay: Option<LengthDistributionConfig>,
176) -> Option<LengthDistributionConfig> {
177 match (base, overlay) {
178 (None, o) => o,
179 (b, None) => b,
180 (Some(b), Some(o)) => Some(LengthDistributionConfig {
181 short: o.short.or(b.short),
182 medium: o.medium.or(b.medium),
183 long: o.long.or(b.long),
184 short_max_words: o.short_max_words.or(b.short_max_words),
185 medium_max_words: o.medium_max_words.or(b.medium_max_words),
186 }),
187 }
188}
189
190fn merge_connectives(
191 base: Option<ConnectivePreferencesConfig>,
192 overlay: Option<ConnectivePreferencesConfig>,
193) -> Option<ConnectivePreferencesConfig> {
194 match (base, overlay) {
195 (None, o) => o,
196 (b, None) => b,
197 (Some(b), Some(o)) => Some(ConnectivePreferencesConfig {
198 allowed: o.allowed.or(b.allowed),
199 preferred: o.preferred.or(b.preferred),
200 }),
201 }
202}
203
204fn merge_hedging(
205 base: Option<HedgingCalibrationConfig>,
206 overlay: Option<HedgingCalibrationConfig>,
207) -> Option<HedgingCalibrationConfig> {
208 match (base, overlay) {
209 (None, o) => o,
210 (b, None) => b,
211 (Some(b), Some(o)) => Some(HedgingCalibrationConfig {
212 offset: o.offset.or(b.offset),
213 forbid: o.forbid.or(b.forbid),
214 }),
215 }
216}
217
218fn build_length_distribution(c: LengthDistributionConfig) -> LengthDistribution {
219 let neutral = LengthDistribution::neutral();
220 LengthDistribution {
221 short: c.short.unwrap_or(neutral.short),
222 medium: c.medium.unwrap_or(neutral.medium),
223 long: c.long.unwrap_or(neutral.long),
224 short_max_words: c.short_max_words.unwrap_or(neutral.short_max_words),
225 medium_max_words: c.medium_max_words.unwrap_or(neutral.medium_max_words),
226 }
227}
228
229fn build_connective_preferences(
230 c: ConnectivePreferencesConfig,
231) -> Result<ConnectivePreferences, ProjectError> {
232 let mut prefs = ConnectivePreferences::neutral();
233 if let Some(allowed) = c.allowed {
234 for (k, v) in allowed {
235 let rst = parse_rst_relation(&k)?;
236 prefs.allowed.insert(rst, v);
237 }
238 }
239 if let Some(preferred) = c.preferred {
240 for (k, v) in preferred {
241 let rst = parse_rst_relation(&k)?;
242 prefs.preferred.insert(rst, v);
243 }
244 }
245 Ok(prefs)
246}
247
248fn build_hedging_calibration(c: HedgingCalibrationConfig) -> HedgingCalibration {
249 HedgingCalibration {
250 offset: c.offset.unwrap_or(0),
251 forbid: c.forbid.unwrap_or_default(),
252 }
253}
254
255fn parse_verbosity(s: &str) -> Result<Verbosity, ProjectError> {
256 match s {
257 "terse" => Ok(Verbosity::Terse),
258 "neutral" => Ok(Verbosity::Neutral),
259 "verbose" => Ok(Verbosity::Verbose),
260 other => Err(ProjectError::ManifestStyle {
261 reason: format!(
262 "unknown verbosity `{other}` — expected one of terse, neutral, verbose"
263 ),
264 }),
265 }
266}
267
268fn parse_list_style_bias(s: &str) -> Result<ListStyleBias, ProjectError> {
269 match s {
270 "auto" => Ok(ListStyleBias::Auto),
271 "including" => Ok(ListStyleBias::Including),
272 "such_as" => Ok(ListStyleBias::SuchAs),
273 "dash" => Ok(ListStyleBias::Dash),
274 "bracketed" => Ok(ListStyleBias::Bracketed),
275 other => Err(ProjectError::ManifestStyle {
276 reason: format!(
277 "unknown list_style_bias `{other}` — expected one of auto, including, such_as, dash, bracketed"
278 ),
279 }),
280 }
281}
282
283fn parse_pronoun_density(s: &str) -> Result<PronounDensity, ProjectError> {
284 match s {
285 "low" => Ok(PronounDensity::Low),
286 "default" => Ok(PronounDensity::Default),
287 "high" => Ok(PronounDensity::High),
288 other => Err(ProjectError::ManifestStyle {
289 reason: format!(
290 "unknown pronoun_density `{other}` — expected one of low, default, high"
291 ),
292 }),
293 }
294}
295
296fn parse_salience_bias(s: &str) -> Result<SalienceBias, ProjectError> {
297 match s {
298 "lower" => Ok(SalienceBias::Lower),
299 "auto" => Ok(SalienceBias::Auto),
300 "higher" => Ok(SalienceBias::Higher),
301 other => Err(ProjectError::ManifestStyle {
302 reason: format!(
303 "unknown salience bias `{other}` — expected one of lower, auto, higher"
304 ),
305 }),
306 }
307}
308
309fn parse_rst_relation(s: &str) -> Result<RstRelation, ProjectError> {
310 match s {
311 "elaboration" => Ok(RstRelation::Elaboration),
312 "contrast" => Ok(RstRelation::Contrast),
313 "cause" => Ok(RstRelation::Cause),
314 "result" => Ok(RstRelation::Result),
315 "concession" => Ok(RstRelation::Concession),
316 "sequence" => Ok(RstRelation::Sequence),
317 "condition" => Ok(RstRelation::Condition),
318 "background" => Ok(RstRelation::Background),
319 "summary" => Ok(RstRelation::Summary),
320 other => Err(ProjectError::ManifestStyle {
321 reason: format!(
322 "unknown RST relation key `{other}` — expected one of elaboration, contrast, cause, result, concession, sequence, condition, background, summary"
323 ),
324 }),
325 }
326}
327
328fn map_style_error(err: StyleProfileError) -> ProjectError {
329 ProjectError::ManifestStyle {
330 reason: err.to_string(),
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Parses `source` as a TOML profile config and builds it against a
    /// fresh, empty temporary directory.
    fn build_from_toml(source: &str) -> Result<StyleProfile, ProjectError> {
        let config: StyleProfileConfig = toml::from_str(source).unwrap();
        let workdir = tempdir().unwrap();
        config.into_style_profile(workdir.path())
    }

    /// Asserts that `outcome` is a `ManifestStyle` error whose reason
    /// contains `needle`.
    fn assert_style_error(outcome: Result<StyleProfile, ProjectError>, needle: &str) {
        assert!(matches!(
            outcome,
            Err(ProjectError::ManifestStyle { reason }) if reason.contains(needle)
        ));
    }

    #[test]
    fn parses_minimal_inline_profile() {
        let profile = build_from_toml(
            r#"
            name = "concise"
            verbosity = "terse"
            list_style_bias = "bracketed"
            "#,
        )
        .unwrap();

        assert_eq!(profile.name, "concise");
        assert_eq!(profile.verbosity, Verbosity::Terse);
        assert_eq!(profile.list_style_bias, ListStyleBias::Bracketed);
        // No `[connectives]` table was given, so the pools stay neutral.
        assert!(profile.connectives.is_neutral());
    }

    #[test]
    fn parses_per_relation_connective_pools() {
        let profile = build_from_toml(
            r#"
            name = "tight-contrast"
            [connectives.allowed]
            elaboration = ["Furthermore,", "Additionally,"]
            contrast = ["However,"]
            [connectives.preferred]
            elaboration = [["Furthermore,", 1.0], ["Additionally,", 0.5]]
            "#,
        )
        .unwrap();

        let pools = &profile.connectives;
        let elaboration_allowed = pools.allowed.get(&RstRelation::Elaboration).map(Vec::len);
        let contrast_allowed = pools.allowed.get(&RstRelation::Contrast).map(Vec::len);
        let elaboration_preferred = pools.preferred.get(&RstRelation::Elaboration).map(Vec::len);

        assert_eq!(elaboration_allowed, Some(2));
        assert_eq!(contrast_allowed, Some(1));
        assert_eq!(elaboration_preferred, Some(2));
    }

    #[test]
    fn unknown_rst_relation_key_is_rejected() {
        let outcome = build_from_toml(
            r#"
            name = "bad"
            [connectives.allowed]
            shrubbery = ["foo"]
            "#,
        );
        assert_style_error(outcome, "shrubbery");
    }

    #[test]
    fn unknown_verbosity_value_is_rejected() {
        let outcome = build_from_toml(
            r#"
            name = "bad"
            verbosity = "yelly"
            "#,
        );
        assert_style_error(outcome, "yelly");
    }

    #[test]
    fn extends_loads_referenced_profile_and_overlays() {
        let workdir = tempdir().unwrap();
        let base_file = workdir.path().join("base.toml");
        fs::write(
            &base_file,
            r#"
            name = "base"
            verbosity = "terse"
            list_style_bias = "bracketed"
            "#,
        )
        .unwrap();

        let child: StyleProfileConfig = toml::from_str(
            r#"
            extends = "base.toml"
            name = "child"
            verbosity = "verbose"
            "#,
        )
        .unwrap();
        let profile = child.into_style_profile(workdir.path()).unwrap();

        // The overlay wins where it sets a value...
        assert_eq!(profile.name, "child");
        assert_eq!(profile.verbosity, Verbosity::Verbose);
        // ...and the base shows through where it does not.
        assert_eq!(profile.list_style_bias, ListStyleBias::Bracketed);
    }

    #[test]
    fn extends_cycle_is_rejected() {
        let workdir = tempdir().unwrap();
        // Two files that extend each other.
        let files = [
            (
                "a.toml",
                r#"
                extends = "b.toml"
                name = "a"
                "#,
            ),
            (
                "b.toml",
                r#"
                extends = "a.toml"
                name = "b"
                "#,
            ),
        ];
        for (file_name, contents) in files {
            fs::write(workdir.path().join(file_name), contents).unwrap();
        }

        let entry = StyleProfileConfig {
            extends: Some("a.toml".to_string()),
            ..Default::default()
        };
        assert_style_error(entry.into_style_profile(workdir.path()), "cycle");
    }

    #[test]
    fn validation_errors_propagate() {
        let outcome = build_from_toml(
            r#"
            name = "bad"
            [hedging]
            offset = 75
            "#,
        );
        assert_style_error(outcome, "75");
    }

    #[test]
    fn empty_config_produces_neutral_profile() {
        let workdir = tempdir().unwrap();
        let profile = StyleProfileConfig::default()
            .into_style_profile(workdir.path())
            .unwrap();
        assert!(profile.is_neutral());
    }
}