Skip to main content

prosaic_project/
style.rs

1//! `prosaic.toml` schema for the `[style_profile]` section.
2//!
3//! `StyleProfileConfig` is the TOML-friendly mirror of
4//! [`prosaic_core::StyleProfile`]. It uses string keys for RST-relation
5//! pools (TOML can't key tables by Rust enum variants directly), keeps
6//! every field optional so projects can declare just the dials they care
7//! about, and supports a single `extends = "path"` reference to another
8//! profile TOML for file-level composition.
9
10use std::collections::HashMap;
11use std::path::{Path, PathBuf};
12
13use prosaic_core::{
14    ConnectivePreferences, HedgingCalibration, LengthDistribution, ListStyleBias, PronounDensity,
15    RstRelation, SalienceBias, StyleProfile, StyleProfileError, Verbosity,
16};
17use serde::{Deserialize, Serialize};
18
19use crate::error::ProjectError;
20
21/// TOML representation of a [`StyleProfile`].
22///
23/// Every field is optional so an authoring file can declare just the dials
24/// it cares about; missing fields fall through to the neutral default.
25/// `extends = "path"` loads another `StyleProfileConfig` as the base; the
26/// inline `Some(_)` fields then override per-dial.
27#[derive(Debug, Default, Clone, Serialize, Deserialize)]
28#[serde(default)]
29pub struct StyleProfileConfig {
30    /// Optional path (relative to the manifest's directory) to another
31    /// profile TOML to use as the base. Inline fields below override the
32    /// base per-dial.
33    pub extends: Option<String>,
34    pub name: Option<String>,
35    pub verbosity: Option<String>,
36    pub list_style_bias: Option<String>,
37    pub pronoun_density: Option<String>,
38    pub salience: Option<String>,
39    pub sentence_length: Option<LengthDistributionConfig>,
40    pub connectives: Option<ConnectivePreferencesConfig>,
41    pub hedging: Option<HedgingCalibrationConfig>,
42}
43
44#[derive(Debug, Default, Clone, Serialize, Deserialize)]
45#[serde(default)]
46pub struct LengthDistributionConfig {
47    pub short: Option<f32>,
48    pub medium: Option<f32>,
49    pub long: Option<f32>,
50    pub short_max_words: Option<u16>,
51    pub medium_max_words: Option<u16>,
52}
53
54#[derive(Debug, Default, Clone, Serialize, Deserialize)]
55#[serde(default)]
56pub struct ConnectivePreferencesConfig {
57    /// Per-RST-relation allowed connective pools. Keys are lowercase
58    /// `RstRelation` variant names: `elaboration`, `contrast`, `cause`,
59    /// `result`, `concession`, `sequence`, `condition`, `background`,
60    /// `summary`. Unknown keys are rejected at parse time so typos
61    /// surface as clear errors instead of silently no-op'ing.
62    pub allowed: Option<HashMap<String, Vec<String>>>,
63    /// Per-RST-relation tie-breaker weights. Inner shape is array of
64    /// `[connective, weight]` pairs.
65    pub preferred: Option<HashMap<String, Vec<(String, f32)>>>,
66}
67
68#[derive(Debug, Default, Clone, Serialize, Deserialize)]
69#[serde(default)]
70pub struct HedgingCalibrationConfig {
71    pub offset: Option<i8>,
72    pub forbid: Option<Vec<String>>,
73}
74
75impl StyleProfileConfig {
76    /// Convert this TOML config into a validated [`StyleProfile`].
77    ///
78    /// `manifest_dir` is the directory containing the project's
79    /// `prosaic.toml`; relative `extends` paths resolve from there. The
80    /// returned profile has been run through [`StyleProfile::validate`]
81    /// — invalid configurations surface as `ProjectError::ManifestStyle`
82    /// rather than silently producing a half-built profile.
83    pub fn into_style_profile(self, manifest_dir: &Path) -> Result<StyleProfile, ProjectError> {
84        let merged = self.resolve(manifest_dir, &mut Vec::new())?;
85        merged.build_profile()
86    }
87
88    fn resolve(
89        self,
90        manifest_dir: &Path,
91        seen: &mut Vec<PathBuf>,
92    ) -> Result<StyleProfileConfig, ProjectError> {
93        // Walk `extends` chain depth-first, base-first. The accumulated
94        // base config gets overlaid by self's `Some(_)` fields at the end.
95        let base = if let Some(ext_path) = &self.extends {
96            let mut path = manifest_dir.join(ext_path);
97            if !path.is_absolute() {
98                path = manifest_dir.join(ext_path);
99            }
100            let canonical = path.canonicalize().unwrap_or(path.clone());
101            if seen.iter().any(|p| p == &canonical) {
102                return Err(ProjectError::ManifestStyle {
103                    reason: format!(
104                        "extends cycle detected: `{}` is already in the resolution chain",
105                        path.display()
106                    ),
107                });
108            }
109            seen.push(canonical);
110            let text = std::fs::read_to_string(&path).map_err(|e| ProjectError::Io {
111                path: path.display().to_string(),
112                cause: e.to_string(),
113            })?;
114            let parent = path
115                .parent()
116                .map(Path::to_path_buf)
117                .unwrap_or_else(|| manifest_dir.to_path_buf());
118            let parsed: StyleProfileConfig =
119                toml::from_str(&text).map_err(|e| ProjectError::TomlParse {
120                    file: path.display().to_string(),
121                    cause: e.to_string(),
122                })?;
123            Some(parsed.resolve(&parent, seen)?)
124        } else {
125            None
126        };
127
128        Ok(merge_overlay(base.unwrap_or_default(), self))
129    }
130
131    fn build_profile(self) -> Result<StyleProfile, ProjectError> {
132        let mut builder =
133            StyleProfile::builder(self.name.unwrap_or_else(|| String::from("default")));
134        if let Some(v) = self.verbosity {
135            builder = builder.verbosity(parse_verbosity(&v)?);
136        }
137        if let Some(l) = self.list_style_bias {
138            builder = builder.list_style_bias(parse_list_style_bias(&l)?);
139        }
140        if let Some(p) = self.pronoun_density {
141            builder = builder.pronoun_density(parse_pronoun_density(&p)?);
142        }
143        if let Some(s) = self.salience {
144            builder = builder.salience(parse_salience_bias(&s)?);
145        }
146        if let Some(sl) = self.sentence_length {
147            builder = builder.sentence_length(build_length_distribution(sl));
148        }
149        if let Some(c) = self.connectives {
150            builder = builder.connectives(build_connective_preferences(c)?);
151        }
152        if let Some(h) = self.hedging {
153            builder = builder.hedging(build_hedging_calibration(h));
154        }
155        builder.build().map_err(map_style_error)
156    }
157}
158
159fn merge_overlay(base: StyleProfileConfig, overlay: StyleProfileConfig) -> StyleProfileConfig {
160    StyleProfileConfig {
161        extends: None, // resolved already
162        name: overlay.name.or(base.name),
163        verbosity: overlay.verbosity.or(base.verbosity),
164        list_style_bias: overlay.list_style_bias.or(base.list_style_bias),
165        pronoun_density: overlay.pronoun_density.or(base.pronoun_density),
166        salience: overlay.salience.or(base.salience),
167        sentence_length: merge_length(base.sentence_length, overlay.sentence_length),
168        connectives: merge_connectives(base.connectives, overlay.connectives),
169        hedging: merge_hedging(base.hedging, overlay.hedging),
170    }
171}
172
173fn merge_length(
174    base: Option<LengthDistributionConfig>,
175    overlay: Option<LengthDistributionConfig>,
176) -> Option<LengthDistributionConfig> {
177    match (base, overlay) {
178        (None, o) => o,
179        (b, None) => b,
180        (Some(b), Some(o)) => Some(LengthDistributionConfig {
181            short: o.short.or(b.short),
182            medium: o.medium.or(b.medium),
183            long: o.long.or(b.long),
184            short_max_words: o.short_max_words.or(b.short_max_words),
185            medium_max_words: o.medium_max_words.or(b.medium_max_words),
186        }),
187    }
188}
189
190fn merge_connectives(
191    base: Option<ConnectivePreferencesConfig>,
192    overlay: Option<ConnectivePreferencesConfig>,
193) -> Option<ConnectivePreferencesConfig> {
194    match (base, overlay) {
195        (None, o) => o,
196        (b, None) => b,
197        (Some(b), Some(o)) => Some(ConnectivePreferencesConfig {
198            allowed: o.allowed.or(b.allowed),
199            preferred: o.preferred.or(b.preferred),
200        }),
201    }
202}
203
204fn merge_hedging(
205    base: Option<HedgingCalibrationConfig>,
206    overlay: Option<HedgingCalibrationConfig>,
207) -> Option<HedgingCalibrationConfig> {
208    match (base, overlay) {
209        (None, o) => o,
210        (b, None) => b,
211        (Some(b), Some(o)) => Some(HedgingCalibrationConfig {
212            offset: o.offset.or(b.offset),
213            forbid: o.forbid.or(b.forbid),
214        }),
215    }
216}
217
218fn build_length_distribution(c: LengthDistributionConfig) -> LengthDistribution {
219    let neutral = LengthDistribution::neutral();
220    LengthDistribution {
221        short: c.short.unwrap_or(neutral.short),
222        medium: c.medium.unwrap_or(neutral.medium),
223        long: c.long.unwrap_or(neutral.long),
224        short_max_words: c.short_max_words.unwrap_or(neutral.short_max_words),
225        medium_max_words: c.medium_max_words.unwrap_or(neutral.medium_max_words),
226    }
227}
228
229fn build_connective_preferences(
230    c: ConnectivePreferencesConfig,
231) -> Result<ConnectivePreferences, ProjectError> {
232    let mut prefs = ConnectivePreferences::neutral();
233    if let Some(allowed) = c.allowed {
234        for (k, v) in allowed {
235            let rst = parse_rst_relation(&k)?;
236            prefs.allowed.insert(rst, v);
237        }
238    }
239    if let Some(preferred) = c.preferred {
240        for (k, v) in preferred {
241            let rst = parse_rst_relation(&k)?;
242            prefs.preferred.insert(rst, v);
243        }
244    }
245    Ok(prefs)
246}
247
248fn build_hedging_calibration(c: HedgingCalibrationConfig) -> HedgingCalibration {
249    HedgingCalibration {
250        offset: c.offset.unwrap_or(0),
251        forbid: c.forbid.unwrap_or_default(),
252    }
253}
254
255fn parse_verbosity(s: &str) -> Result<Verbosity, ProjectError> {
256    match s {
257        "terse" => Ok(Verbosity::Terse),
258        "neutral" => Ok(Verbosity::Neutral),
259        "verbose" => Ok(Verbosity::Verbose),
260        other => Err(ProjectError::ManifestStyle {
261            reason: format!(
262                "unknown verbosity `{other}` — expected one of terse, neutral, verbose"
263            ),
264        }),
265    }
266}
267
268fn parse_list_style_bias(s: &str) -> Result<ListStyleBias, ProjectError> {
269    match s {
270        "auto" => Ok(ListStyleBias::Auto),
271        "including" => Ok(ListStyleBias::Including),
272        "such_as" => Ok(ListStyleBias::SuchAs),
273        "dash" => Ok(ListStyleBias::Dash),
274        "bracketed" => Ok(ListStyleBias::Bracketed),
275        other => Err(ProjectError::ManifestStyle {
276            reason: format!(
277                "unknown list_style_bias `{other}` — expected one of auto, including, such_as, dash, bracketed"
278            ),
279        }),
280    }
281}
282
283fn parse_pronoun_density(s: &str) -> Result<PronounDensity, ProjectError> {
284    match s {
285        "low" => Ok(PronounDensity::Low),
286        "default" => Ok(PronounDensity::Default),
287        "high" => Ok(PronounDensity::High),
288        other => Err(ProjectError::ManifestStyle {
289            reason: format!(
290                "unknown pronoun_density `{other}` — expected one of low, default, high"
291            ),
292        }),
293    }
294}
295
296fn parse_salience_bias(s: &str) -> Result<SalienceBias, ProjectError> {
297    match s {
298        "lower" => Ok(SalienceBias::Lower),
299        "auto" => Ok(SalienceBias::Auto),
300        "higher" => Ok(SalienceBias::Higher),
301        other => Err(ProjectError::ManifestStyle {
302            reason: format!(
303                "unknown salience bias `{other}` — expected one of lower, auto, higher"
304            ),
305        }),
306    }
307}
308
309fn parse_rst_relation(s: &str) -> Result<RstRelation, ProjectError> {
310    match s {
311        "elaboration" => Ok(RstRelation::Elaboration),
312        "contrast" => Ok(RstRelation::Contrast),
313        "cause" => Ok(RstRelation::Cause),
314        "result" => Ok(RstRelation::Result),
315        "concession" => Ok(RstRelation::Concession),
316        "sequence" => Ok(RstRelation::Sequence),
317        "condition" => Ok(RstRelation::Condition),
318        "background" => Ok(RstRelation::Background),
319        "summary" => Ok(RstRelation::Summary),
320        other => Err(ProjectError::ManifestStyle {
321            reason: format!(
322                "unknown RST relation key `{other}` — expected one of elaboration, contrast, cause, result, concession, sequence, condition, background, summary"
323            ),
324        }),
325    }
326}
327
328fn map_style_error(err: StyleProfileError) -> ProjectError {
329    ProjectError::ManifestStyle {
330        reason: err.to_string(),
331    }
332}
333
#[cfg(test)]
mod tests {
    // Unit tests for `StyleProfileConfig` parsing, `extends` resolution and
    // error reporting. Files are written into throwaway temp directories.
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    #[test]
    fn parses_minimal_inline_profile() {
        let toml_str = r#"
            name = "concise"
            verbosity = "terse"
            list_style_bias = "bracketed"
        "#;
        let cfg: StyleProfileConfig = toml::from_str(toml_str).unwrap();
        let dir = tempdir().unwrap();
        let profile = cfg.into_style_profile(dir.path()).unwrap();
        assert_eq!(profile.name, "concise");
        assert_eq!(profile.verbosity, Verbosity::Terse);
        assert_eq!(profile.list_style_bias, ListStyleBias::Bracketed);
        assert!(profile.connectives.is_neutral());
    }

    #[test]
    fn parses_per_relation_connective_pools() {
        let toml_str = r#"
            name = "tight-contrast"
            [connectives.allowed]
            elaboration = ["Furthermore,", "Additionally,"]
            contrast = ["However,"]
            [connectives.preferred]
            elaboration = [["Furthermore,", 1.0], ["Additionally,", 0.5]]
        "#;
        let cfg: StyleProfileConfig = toml::from_str(toml_str).unwrap();
        let dir = tempdir().unwrap();
        let profile = cfg.into_style_profile(dir.path()).unwrap();
        assert_eq!(
            profile
                .connectives
                .allowed
                .get(&RstRelation::Elaboration)
                .map(Vec::len),
            Some(2)
        );
        assert_eq!(
            profile
                .connectives
                .allowed
                .get(&RstRelation::Contrast)
                .map(Vec::len),
            Some(1)
        );
        assert_eq!(
            profile
                .connectives
                .preferred
                .get(&RstRelation::Elaboration)
                .map(Vec::len),
            Some(2)
        );
    }

    #[test]
    fn unknown_rst_relation_key_is_rejected() {
        let toml_str = r#"
            name = "bad"
            [connectives.allowed]
            shrubbery = ["foo"]
        "#;
        let cfg: StyleProfileConfig = toml::from_str(toml_str).unwrap();
        let dir = tempdir().unwrap();
        let result = cfg.into_style_profile(dir.path());
        assert!(matches!(
            result,
            Err(ProjectError::ManifestStyle { reason }) if reason.contains("shrubbery")
        ));
    }

    #[test]
    fn unknown_verbosity_value_is_rejected() {
        let toml_str = r#"
            name = "bad"
            verbosity = "yelly"
        "#;
        let cfg: StyleProfileConfig = toml::from_str(toml_str).unwrap();
        let dir = tempdir().unwrap();
        let result = cfg.into_style_profile(dir.path());
        assert!(matches!(
            result,
            Err(ProjectError::ManifestStyle { reason }) if reason.contains("yelly")
        ));
    }

    #[test]
    fn extends_loads_referenced_profile_and_overlays() {
        let dir = tempdir().unwrap();
        let base_path = dir.path().join("base.toml");
        fs::write(
            &base_path,
            r#"
                name = "base"
                verbosity = "terse"
                list_style_bias = "bracketed"
            "#,
        )
        .unwrap();

        // Inline overrides only verbosity; list_style_bias stays from base.
        let overlay_toml = r#"
            extends = "base.toml"
            name = "child"
            verbosity = "verbose"
        "#;
        let cfg: StyleProfileConfig = toml::from_str(overlay_toml).unwrap();
        let profile = cfg.into_style_profile(dir.path()).unwrap();
        assert_eq!(profile.name, "child");
        assert_eq!(profile.verbosity, Verbosity::Verbose);
        assert_eq!(profile.list_style_bias, ListStyleBias::Bracketed);
    }

    #[test]
    fn extends_accepts_absolute_path() {
        let base_dir = tempdir().unwrap();
        let base_path = base_dir.path().join("base.toml");
        fs::write(
            &base_path,
            r#"
                verbosity = "terse"
                list_style_bias = "bracketed"
            "#,
        )
        .unwrap();

        // Resolve from an unrelated directory: an absolute `extends` must be
        // used as-is rather than re-rooted under `manifest_dir`.
        let manifest_dir = tempdir().unwrap();
        let cfg = StyleProfileConfig {
            extends: Some(base_path.to_str().unwrap().to_string()),
            ..Default::default()
        };
        let profile = cfg.into_style_profile(manifest_dir.path()).unwrap();
        assert_eq!(profile.verbosity, Verbosity::Terse);
        assert_eq!(profile.list_style_bias, ListStyleBias::Bracketed);
    }

    #[test]
    fn nested_extends_resolve_relative_to_declaring_file() {
        let dir = tempdir().unwrap();
        let profiles = dir.path().join("profiles");
        fs::create_dir(&profiles).unwrap();
        fs::write(
            profiles.join("root.toml"),
            r#"
                list_style_bias = "bracketed"
            "#,
        )
        .unwrap();
        // `root.toml` sits next to `mid.toml`, not next to the manifest, so
        // this only resolves if nested `extends` use the declaring file's dir.
        fs::write(
            profiles.join("mid.toml"),
            r#"
                extends = "root.toml"
                verbosity = "verbose"
            "#,
        )
        .unwrap();

        let cfg = StyleProfileConfig {
            extends: Some("profiles/mid.toml".to_string()),
            ..Default::default()
        };
        let profile = cfg.into_style_profile(dir.path()).unwrap();
        assert_eq!(profile.verbosity, Verbosity::Verbose);
        assert_eq!(profile.list_style_bias, ListStyleBias::Bracketed);
    }

    #[test]
    fn missing_extends_file_is_io_error() {
        let dir = tempdir().unwrap();
        let cfg = StyleProfileConfig {
            extends: Some("does-not-exist.toml".to_string()),
            ..Default::default()
        };
        let result = cfg.into_style_profile(dir.path());
        assert!(matches!(result, Err(ProjectError::Io { .. })));
    }

    #[test]
    fn extends_cycle_is_rejected() {
        let dir = tempdir().unwrap();
        fs::write(
            dir.path().join("a.toml"),
            r#"
                extends = "b.toml"
                name = "a"
            "#,
        )
        .unwrap();
        fs::write(
            dir.path().join("b.toml"),
            r#"
                extends = "a.toml"
                name = "b"
            "#,
        )
        .unwrap();
        let cfg = StyleProfileConfig {
            extends: Some("a.toml".to_string()),
            ..Default::default()
        };
        let result = cfg.into_style_profile(dir.path());
        assert!(matches!(
            result,
            Err(ProjectError::ManifestStyle { reason }) if reason.contains("cycle")
        ));
    }

    #[test]
    fn validation_errors_propagate() {
        let toml_str = r#"
            name = "bad"
            [hedging]
            offset = 75
        "#;
        let cfg: StyleProfileConfig = toml::from_str(toml_str).unwrap();
        let dir = tempdir().unwrap();
        let result = cfg.into_style_profile(dir.path());
        assert!(matches!(
            result,
            Err(ProjectError::ManifestStyle { reason }) if reason.contains("75")
        ));
    }

    #[test]
    fn empty_config_produces_neutral_profile() {
        let cfg = StyleProfileConfig::default();
        let dir = tempdir().unwrap();
        let profile = cfg.into_style_profile(dir.path()).unwrap();
        assert!(profile.is_neutral());
    }
}