1use std::collections::HashMap;
4use std::fs;
5use std::path::{Path, PathBuf};
6
7use prosaic_core::Context;
8
9use crate::error::ProjectError;
10use crate::fixture::parse_fixture;
11use crate::manifest::Manifest;
12use crate::partial::PartialFile;
13use crate::scenario::Scenario;
14use crate::template::TemplateFile;
15
16#[derive(Debug, Clone)]
17pub struct Project {
18 pub root: PathBuf,
19 pub manifest: Manifest,
20 pub templates: HashMap<String, TemplateFile>,
21 pub partials: HashMap<String, PartialFile>,
22 pub fixtures: HashMap<String, Context>,
23 pub scenarios: HashMap<String, Scenario>,
24}
25
26#[derive(Debug, Clone)]
27pub struct ValidationIssue {
28 pub level: ValidationLevel,
29 pub location: String,
30 pub message: String,
31}
32
/// Severity attached to a validation finding.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ValidationLevel {
    /// A problem that should be fixed.
    Error,
    /// A non-fatal, advisory finding.
    Warning,
}
38
39impl Project {
40 pub fn load_from_dir(path: impl AsRef<Path>) -> Result<Self, ProjectError> {
41 let root = path.as_ref().to_path_buf();
42
43 let manifest_path = root.join("prosaic.toml");
44 if !manifest_path.exists() {
45 return Err(ProjectError::ManifestMissing {
46 path: manifest_path.display().to_string(),
47 });
48 }
49 let manifest_str = fs::read_to_string(&manifest_path).map_err(|e| ProjectError::Io {
50 path: manifest_path.display().to_string(),
51 cause: e.to_string(),
52 })?;
53 let manifest: Manifest =
54 toml::from_str(&manifest_str).map_err(|e| ProjectError::TomlParse {
55 file: "prosaic.toml".to_string(),
56 cause: e.to_string(),
57 })?;
58
59 let templates =
60 load_toml_dir::<TemplateFile, _>(&root.join("templates"), |t| t.key.clone())?;
61 let partials = load_toml_dir::<PartialFile, _>(&root.join("partials"), |p| p.name.clone())?;
62 let scenarios = load_toml_dir::<Scenario, _>(&root.join("tests"), |s| s.name.clone())?;
63 let fixtures = load_fixtures_dir(&root.join("fixtures"))?;
64
65 Ok(Project {
66 root,
67 manifest,
68 templates,
69 partials,
70 fixtures,
71 scenarios,
72 })
73 }
74}
75
76fn load_toml_dir<T, F>(dir: &Path, key_fn: F) -> Result<HashMap<String, T>, ProjectError>
77where
78 T: serde::de::DeserializeOwned,
79 F: Fn(&T) -> String,
80{
81 let mut out = HashMap::new();
82 if !dir.exists() {
83 return Ok(out);
84 }
85 for entry in fs::read_dir(dir).map_err(|e| ProjectError::Io {
86 path: dir.display().to_string(),
87 cause: e.to_string(),
88 })? {
89 let entry = entry.map_err(|e| ProjectError::Io {
90 path: dir.display().to_string(),
91 cause: e.to_string(),
92 })?;
93 let path = entry.path();
94 if path.extension().map(|e| e == "toml").unwrap_or(false) {
95 let text = fs::read_to_string(&path).map_err(|e| ProjectError::Io {
96 path: path.display().to_string(),
97 cause: e.to_string(),
98 })?;
99 let parsed: T = toml::from_str(&text).map_err(|e| ProjectError::TomlParse {
100 file: path.file_name().unwrap().to_string_lossy().to_string(),
101 cause: e.to_string(),
102 })?;
103 let key = key_fn(&parsed);
104 out.insert(key, parsed);
105 }
106 }
107 Ok(out)
108}
109
110use prosaic_core::{Engine, Salience, SalienceThresholds, Strictness, Variation};
111use prosaic_grammar_en::English;
112
/// Pipe names that `Project::validate` accepts inside template bodies.
///
/// Any other pipe name is reported as an "unknown pipe" error.
/// NOTE(review): this list is maintained by hand and presumably must match the
/// pipes the engine actually supports — confirm when pipes are added or removed.
const KNOWN_PIPES: &[&str] = &[
    "plural",
    "pluralize",
    "article",
    "join",
    "ordinal",
    "words",
    "truncate",
    "capitalize",
    "refer",
    "verb",
    "syn",
    "relative",
    "since_last",
    "quantify",
    "proportion",
    "hedge",
    "negated",
    "choose",
    "demonstrative",
];
134
135impl Project {
136 pub fn validate(&self) -> Vec<ValidationIssue> {
139 let mut issues = Vec::new();
140 let known_partials: std::collections::HashSet<_> = self.partials.keys().cloned().collect();
141
142 for (key, template) in &self.templates {
143 for (vi, variant) in template.variants.iter().enumerate() {
144 let parsed = match prosaic_core::Template::parse(&variant.body) {
145 Ok(p) => p,
146 Err(e) => {
147 issues.push(ValidationIssue {
148 level: ValidationLevel::Error,
149 location: format!("templates/{key}.toml#variant[{vi}]"),
150 message: format!("template parse error: {e}"),
151 });
152 continue;
153 }
154 };
155 for pipe_name in parsed.pipe_names() {
156 if !KNOWN_PIPES.contains(&pipe_name.as_str()) {
157 issues.push(ValidationIssue {
158 level: ValidationLevel::Error,
159 location: format!("templates/{key}.toml#variant[{vi}]"),
160 message: format!("unknown pipe `{pipe_name}`"),
161 });
162 }
163 }
164 for partial_name in parsed.partial_names() {
165 if !known_partials.contains(&partial_name) {
166 issues.push(ValidationIssue {
167 level: ValidationLevel::Error,
168 location: format!("templates/{key}.toml#variant[{vi}]"),
169 message: format!("unknown partial `{partial_name}`"),
170 });
171 }
172 }
173 }
174 }
175
176 issues
177 }
178
179 pub fn save_template(&self, key: &str) -> Result<(), ProjectError> {
181 let template = self
182 .templates
183 .get(key)
184 .ok_or_else(|| ProjectError::TemplateValidation {
185 key: key.to_string(),
186 reason: "template not present in project".to_string(),
187 })?;
188 let dir = self.root.join("templates");
189 if !dir.exists() {
190 fs::create_dir_all(&dir).map_err(|e| ProjectError::Io {
191 path: dir.display().to_string(),
192 cause: e.to_string(),
193 })?;
194 }
195 let serialized = toml::to_string_pretty(template).map_err(|e| ProjectError::TomlParse {
196 file: format!("{key}.toml"),
197 cause: e.to_string(),
198 })?;
199 let path = dir.join(format!("{key}.toml"));
200 fs::write(&path, serialized).map_err(|e| ProjectError::Io {
201 path: path.display().to_string(),
202 cause: e.to_string(),
203 })
204 }
205
206 pub fn save_partial(&self, name: &str) -> Result<(), ProjectError> {
208 let partial = self
209 .partials
210 .get(name)
211 .ok_or_else(|| ProjectError::PartialValidation {
212 name: name.to_string(),
213 reason: "partial not present in project".to_string(),
214 })?;
215 let dir = self.root.join("partials");
216 if !dir.exists() {
217 fs::create_dir_all(&dir).map_err(|e| ProjectError::Io {
218 path: dir.display().to_string(),
219 cause: e.to_string(),
220 })?;
221 }
222 let serialized = toml::to_string_pretty(partial).map_err(|e| ProjectError::TomlParse {
223 file: format!("{name}.toml"),
224 cause: e.to_string(),
225 })?;
226 let path = dir.join(format!("{name}.toml"));
227 fs::write(&path, serialized).map_err(|e| ProjectError::Io {
228 path: path.display().to_string(),
229 cause: e.to_string(),
230 })
231 }
232
233 pub fn save_scenario(&self, name: &str) -> Result<(), ProjectError> {
235 let scenario =
236 self.scenarios
237 .get(name)
238 .ok_or_else(|| ProjectError::ScenarioValidation {
239 name: name.to_string(),
240 reason: "scenario not present in project".to_string(),
241 })?;
242 let dir = self.root.join("tests");
243 if !dir.exists() {
244 fs::create_dir_all(&dir).map_err(|e| ProjectError::Io {
245 path: dir.display().to_string(),
246 cause: e.to_string(),
247 })?;
248 }
249 let serialized = toml::to_string_pretty(scenario).map_err(|e| ProjectError::TomlParse {
250 file: format!("{name}.toml"),
251 cause: e.to_string(),
252 })?;
253 let path = dir.join(format!("{name}.toml"));
254 fs::write(&path, serialized).map_err(|e| ProjectError::Io {
255 path: path.display().to_string(),
256 cause: e.to_string(),
257 })
258 }
259}
260
261impl Project {
262 pub fn into_engine(&self) -> Result<Engine, ProjectError> {
267 let mut engine = Engine::new(English::new());
268
269 let s = &self.manifest.engine;
270 engine = match s.strictness.as_str() {
271 "strict" => engine.strictness(Strictness::Strict),
272 "lenient" => engine.strictness(Strictness::Lenient),
273 "silent" => engine.strictness(Strictness::Silent),
274 other => {
275 return Err(ProjectError::TemplateValidation {
276 key: "(manifest)".to_string(),
277 reason: format!("unknown strictness `{other}`"),
278 });
279 }
280 };
281 engine = match s.variation.as_str() {
282 "fixed" => engine.variation(Variation::Fixed),
283 "round_robin" => engine.variation(Variation::RoundRobin),
284 "random" => engine.variation(Variation::Random),
285 other => {
286 return Err(ProjectError::TemplateValidation {
287 key: "(manifest)".to_string(),
288 reason: format!("unknown variation `{other}`"),
289 });
290 }
291 };
292 if s.smart_quotes {
293 engine = engine.smart_quotes(true);
294 }
295 if s.max_sentence_length > 0 {
296 engine = engine.max_sentence_length(s.max_sentence_length);
297 }
298 if s.faithfulness_min > 0.0 {
299 engine = engine.with_faithfulness_gate(s.faithfulness_min as f32);
300 }
301 if let Some(thr) = &s.salience_thresholds {
302 engine = engine.salience_thresholds(SalienceThresholds {
303 low_max: thr.low_max,
304 high_min: thr.high_min,
305 });
306 }
307 if let Some(style) = &s.style {
308 engine = engine.style_preference(style);
309 }
310 if let Some(profile_cfg) = &self.manifest.style_profile {
311 let profile = profile_cfg.clone().into_style_profile(&self.root)?;
312 engine = engine.style_profile(profile);
313 }
314 engine = engine.language_preference(&self.manifest.language);
315
316 for (name, partial) in &self.partials {
317 engine.register_partial(name, &partial.body).map_err(|e| {
318 ProjectError::PartialValidation {
319 name: name.clone(),
320 reason: e.to_string(),
321 }
322 })?;
323 }
324
325 for (key, template) in &self.templates {
326 for variant in &template.variants {
327 let salience = match variant.salience.as_str() {
328 "low" => Salience::Low,
329 "medium" => Salience::Medium,
330 "high" => Salience::High,
331 other => {
332 return Err(ProjectError::TemplateValidation {
333 key: key.clone(),
334 reason: format!("unknown salience `{other}`"),
335 });
336 }
337 };
338 let language = variant.language.as_deref();
339 let style = variant.style.as_deref();
340 engine
341 .register_template_with_language_and_style_at(
342 key,
343 &variant.body,
344 salience,
345 language,
346 style,
347 )
348 .map_err(|e| ProjectError::TemplateValidation {
349 key: key.clone(),
350 reason: e.to_string(),
351 })?;
352 }
353 }
354
355 Ok(engine)
356 }
357}
358
359fn load_fixtures_dir(dir: &Path) -> Result<HashMap<String, Context>, ProjectError> {
360 let mut out = HashMap::new();
361 if !dir.exists() {
362 return Ok(out);
363 }
364 for entry in fs::read_dir(dir).map_err(|e| ProjectError::Io {
365 path: dir.display().to_string(),
366 cause: e.to_string(),
367 })? {
368 let entry = entry.map_err(|e| ProjectError::Io {
369 path: dir.display().to_string(),
370 cause: e.to_string(),
371 })?;
372 let path = entry.path();
373 if path.extension().map(|e| e == "json").unwrap_or(false) {
374 let stem = path
375 .file_stem()
376 .ok_or_else(|| ProjectError::Io {
377 path: path.display().to_string(),
378 cause: "file has no stem".to_string(),
379 })?
380 .to_string_lossy()
381 .to_string();
382 let text = fs::read_to_string(&path).map_err(|e| ProjectError::Io {
383 path: path.display().to_string(),
384 cause: e.to_string(),
385 })?;
386 let ctx = parse_fixture(&stem, &text)?;
387 out.insert(stem, ctx);
388 }
389 }
390 Ok(out)
391}