1use std::collections::HashMap;
7use std::fs;
8use std::io::{self, BufRead};
9use std::path::{Path, PathBuf};
10
11use crate::symbol::{NumType, Symbol};
12
13pub use crate::udf::UserFunction;
15
/// A user-defined constant, typically declared in a profile with
/// `-X "weight:name:description:value"`.
#[derive(Clone, Debug)]
pub struct UserConstant {
    /// Weight taken from the first field of the `-X` directive.
    // NOTE(review): the reason for this `dead_code` allowance is not visible
    // here; presumably the field is unread in some build configuration — confirm.
    #[allow(dead_code)]
    pub weight: u32,
    /// Constant name (second `-X` field). `Profile::add_constant` rejects
    /// empty names.
    pub name: String,
    /// Free-form description (third `-X` field).
    // NOTE(review): `dead_code` allowance is unexplained here — confirm why.
    #[allow(dead_code)]
    pub description: String,
    /// Numeric value (fourth `-X` field). `Profile::add_constant` rejects
    /// non-finite values.
    pub value: f64,
    /// Number class assigned by `Profile::add_constant`: `Integer` for
    /// integral values with magnitude below 1e10, otherwise `Rational` when
    /// `is_rational` accepts the value, otherwise `Transcendental`.
    pub num_type: NumType,
}
37
/// Settings gathered from profile files such as `~/.ries_profile` and `./.ries`.
///
/// Each field corresponds to one kind of profile directive. Profiles loaded
/// from several files are combined with `Profile::merge`.
#[derive(Clone, Debug, Default)]
pub struct Profile {
    /// Constants declared with `-X`.
    pub constants: Vec<UserConstant>,
    /// Functions declared with `--define`.
    pub functions: Vec<UserFunction>,
    /// Per-symbol name overrides from `--symbol-names`.
    pub symbol_names: HashMap<Symbol, String>,
    /// Per-symbol weight overrides from `--symbol-weights`.
    pub symbol_weights: HashMap<Symbol, u32>,
    /// Resolved paths of files pulled in via `--include`, including nested ones.
    pub includes: Vec<PathBuf>,
}
52
53impl Profile {
54 pub fn new() -> Self {
56 Self::default()
57 }
58
59 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, ProfileError> {
61 load_profile_recursive(path.as_ref(), &mut Vec::new(), 0)
62 }
63
64 pub fn load_default() -> Self {
66 let mut profile = Profile::new();
67
68 if let Some(home) = dirs::home_dir() {
70 let home_profile = home.join(".ries_profile");
71 if home_profile.exists() {
72 if let Ok(p) = Self::from_file(&home_profile) {
73 profile = profile.merge(p);
74 }
75 }
76 }
77
78 let local_profile = PathBuf::from(".ries");
80 if local_profile.exists() {
81 if let Ok(p) = Self::from_file(&local_profile) {
82 profile = profile.merge(p);
83 }
84 }
85
86 profile
87 }
88
89 pub fn add_constant(
107 &mut self,
108 weight: u32,
109 name: String,
110 description: String,
111 value: f64,
112 ) -> Result<(), ProfileError> {
113 if name.is_empty() {
115 return Err(ProfileError::ValidationError(
116 "Constant name cannot be empty".to_string(),
117 ));
118 }
119
120 if !value.is_finite() {
122 return Err(ProfileError::ValidationError(format!(
123 "Constant value must be finite (got {})",
124 value
125 )));
126 }
127
128 let num_type = if value.fract() == 0.0 && value.abs() < 1e10 {
130 NumType::Integer
131 } else if is_rational(value) {
132 NumType::Rational
133 } else {
134 NumType::Transcendental
135 };
136
137 self.constants.push(UserConstant {
138 weight,
139 name,
140 description,
141 value,
142 num_type,
143 });
144
145 Ok(())
146 }
147
148 pub fn merge(mut self, other: Profile) -> Self {
150 for c in other.constants {
152 self.constants.retain(|existing| existing.name != c.name);
154 self.constants.push(c);
155 }
156
157 for f in other.functions {
159 self.functions.retain(|existing| existing.name != f.name);
161 self.functions.push(f);
162 }
163
164 self.symbol_names.extend(other.symbol_names);
166
167 self.symbol_weights.extend(other.symbol_weights);
169
170 self.includes.extend(other.includes);
172
173 self
174 }
175}
176
177const MAX_INCLUDE_DEPTH: usize = 25;
178
179fn load_profile_recursive(
180 path: &Path,
181 include_stack: &mut Vec<PathBuf>,
182 depth: usize,
183) -> Result<Profile, ProfileError> {
184 if depth > MAX_INCLUDE_DEPTH {
185 return Err(ProfileError::ParseError(
186 path.to_path_buf(),
187 0,
188 format!("Profile include depth exceeded {}", MAX_INCLUDE_DEPTH),
189 ));
190 }
191
192 let canonical = fs::canonicalize(path).unwrap_or_else(|_| path.to_path_buf());
193 if include_stack.contains(&canonical) {
194 return Err(ProfileError::ParseError(
195 path.to_path_buf(),
196 0,
197 "Recursive --include detected".to_string(),
198 ));
199 }
200 include_stack.push(canonical);
201
202 let file = fs::File::open(path).map_err(|e| ProfileError::IoError(path.to_path_buf(), e))?;
203 let mut profile = Profile::new();
204 let reader = io::BufReader::new(file);
205
206 for (line_num, line_result) in reader.lines().enumerate() {
207 let line = line_result.map_err(|e| ProfileError::IoError(path.to_path_buf(), e))?;
208 let trimmed = line.trim();
209 if trimmed.is_empty() || trimmed.starts_with('#') {
210 continue;
211 }
212
213 if trimmed.starts_with("--include") {
214 let include_raw = parse_include_path(trimmed)
215 .map_err(|e| ProfileError::ParseError(path.to_path_buf(), line_num + 1, e))?;
216
217 let include_resolved = resolve_include_path(path, &include_raw).ok_or_else(|| {
218 ProfileError::ParseError(
219 path.to_path_buf(),
220 line_num + 1,
221 format!(
222 "Could not open '{}' or '{}.ries' for reading",
223 include_raw.display(),
224 include_raw.display()
225 ),
226 )
227 })?;
228
229 profile.includes.push(include_resolved.clone());
230 let nested = load_profile_recursive(&include_resolved, include_stack, depth + 1)?;
231 profile = profile.merge(nested);
232 continue;
233 }
234
235 if let Err(e) = parse_profile_line(&mut profile, trimmed) {
236 return Err(ProfileError::ParseError(
237 path.to_path_buf(),
238 line_num + 1,
239 e,
240 ));
241 }
242 }
243
244 include_stack.pop();
245 Ok(profile)
246}
247
/// Resolve an `--include` argument to an existing regular file.
///
/// Absolute paths are used as given. Relative paths are taken relative to
/// the directory containing `current_file`, or `.` if it has no parent.
/// The literal path is tried first, then the same path with `.ries` appended.
///
/// Only regular files are accepted; `is_file` follows symlinks. A directory
/// that shares the include's name therefore does not shadow `name.ries`.
///
/// Returns `None` when neither candidate is a file.
fn resolve_include_path(current_file: &Path, include_path: &Path) -> Option<PathBuf> {
    let base = current_file.parent().unwrap_or_else(|| Path::new("."));
    let anchor = |candidate: &Path| -> PathBuf {
        if candidate.is_absolute() {
            candidate.to_path_buf()
        } else {
            base.join(candidate)
        }
    };

    let mut with_suffix = include_path.as_os_str().to_os_string();
    with_suffix.push(".ries");

    vec![anchor(include_path), anchor(Path::new(&with_suffix))]
        .into_iter()
        .find(|candidate| candidate.is_file())
}
269
270fn parse_profile_line(profile: &mut Profile, line: &str) -> Result<(), String> {
272 if line.starts_with("-X") {
274 return parse_user_constant(profile, line);
275 }
276
277 if line.starts_with("--define") {
279 return parse_user_function(profile, line);
280 }
281
282 if line.starts_with("--symbol-names") {
284 return parse_symbol_names(profile, line);
285 }
286
287 if line.starts_with("--symbol-weights") {
289 return parse_symbol_weights(profile, line);
290 }
291
292 Ok(())
295}
296
297fn parse_user_constant(profile: &mut Profile, line: &str) -> Result<(), String> {
300 let rest = line[2..].trim();
302
303 let content = if let Some(stripped) = rest.strip_prefix('"') {
305 let end_quote = stripped.find('"').ok_or("Unclosed quote in -X directive")?;
307 &stripped[..end_quote]
308 } else {
309 rest
311 };
312
313 let parts: Vec<&str> = content.split(':').collect();
314 if parts.len() != 4 {
315 return Err(format!(
316 "Invalid -X format: expected 4 colon-separated parts, got {}",
317 parts.len()
318 ));
319 }
320
321 let weight: u32 = parts[0]
322 .parse()
323 .map_err(|_| format!("Invalid weight: {}", parts[0]))?;
324
325 let name = parts[1].to_string();
326 let description = parts[2].to_string();
327
328 let value: f64 = parts[3]
329 .parse()
330 .map_err(|_| format!("Invalid value: {}", parts[3]))?;
331
332 profile
334 .add_constant(weight, name, description, value)
335 .map_err(|e| e.to_string())?;
336
337 Ok(())
338}
339
/// Return `true` if `v` is (within 1e-10) equal to `p / q` for some integer
/// `p` and denominator `q` in `1..=100`.
///
/// - Zero counts as rational.
/// - NaN and infinities are not rational.
/// - A nonzero value never matches numerator 0. Without this rule, the
///   absolute tolerance would classify every tiny constant (e.g. 6.6e-34)
///   as rational.
fn is_rational(v: f64) -> bool {
    if v == 0.0 {
        return true;
    }
    if !v.is_finite() {
        return false;
    }

    (1..=100_u32).any(|denom| {
        let numer = v * f64::from(denom);
        let nearest = numer.round();
        nearest != 0.0 && (nearest - numer).abs() < 1e-10
    })
}
355
356fn parse_user_function(profile: &mut Profile, line: &str) -> Result<(), String> {
359 let rest = line["--define".len()..].trim();
361
362 let content = if let Some(stripped) = rest.strip_prefix('"') {
364 let end_quote = stripped
366 .find('"')
367 .ok_or("Unclosed quote in --define directive")?;
368 &stripped[..end_quote]
369 } else {
370 rest
372 };
373
374 let udf = UserFunction::parse(content)?;
376 profile.functions.push(udf);
377
378 Ok(())
379}
380
381fn parse_symbol_names(profile: &mut Profile, line: &str) -> Result<(), String> {
384 let rest = line["--symbol-names".len()..].trim();
385
386 for part in rest.split_whitespace() {
387 if !part.starts_with(':') {
388 continue;
389 }
390
391 let inner = &part[1..];
392 if let Some(colon_pos) = inner.find(':') {
393 let symbol_char = inner[..colon_pos]
394 .chars()
395 .next()
396 .ok_or("Empty symbol in --symbol-names")?;
397 let name = inner[colon_pos + 1..].to_string();
398
399 if let Some(symbol) = Symbol::from_byte(symbol_char as u8) {
400 profile.symbol_names.insert(symbol, name);
401 }
402 }
403 }
404
405 Ok(())
406}
407
408fn parse_symbol_weights(profile: &mut Profile, line: &str) -> Result<(), String> {
411 let rest = line["--symbol-weights".len()..].trim();
412
413 for part in rest.split_whitespace() {
414 if !part.starts_with(':') {
415 continue;
416 }
417
418 let inner = &part[1..];
419 if let Some(colon_pos) = inner.find(':') {
420 let symbol_char = inner[..colon_pos]
421 .chars()
422 .next()
423 .ok_or("Empty symbol in --symbol-weights")?;
424 let weight: u32 = inner[colon_pos + 1..]
425 .parse()
426 .map_err(|_| format!("Invalid weight in --symbol-weights: {}", inner))?;
427
428 if let Some(symbol) = Symbol::from_byte(symbol_char as u8) {
429 profile.symbol_weights.insert(symbol, weight);
430 }
431 }
432 }
433
434 Ok(())
435}
436
/// Extract the file name argument from an `--include` line.
///
/// The text after `--include` is trimmed. If it is wrapped in double quotes,
/// with `"` at both ends, the quotes are removed; otherwise the text is used
/// verbatim. A lone `"` is not treated as a quoted name, so it is kept as-is
/// rather than slicing out of bounds.
///
/// # Errors
///
/// Returns `"--include requires a filename"` when no name is given,
/// including an empty quoted name (`""`).
fn parse_include_path(line: &str) -> Result<PathBuf, String> {
    let rest = line["--include".len()..].trim();

    let path_str = match rest.strip_prefix('"').and_then(|s| s.strip_suffix('"')) {
        Some(inner) => inner,
        None => rest,
    };

    if path_str.is_empty() {
        return Err("--include requires a filename".to_string());
    }

    Ok(PathBuf::from(path_str))
}
455
/// Errors produced while loading or validating a profile.
#[derive(Debug)]
pub enum ProfileError {
    /// Opening or reading the given file failed.
    IoError(PathBuf, io::Error),
    /// A directive in the given file was malformed. The `usize` is the
    /// 1-based line number, or 0 when the error is not tied to a line
    /// (include depth exceeded, recursive include).
    ParseError(PathBuf, usize, String),
    /// A value failed validation, e.g. an empty constant name or a
    /// non-finite constant value.
    ValidationError(String),
}
466
467impl std::fmt::Display for ProfileError {
468 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
469 match self {
470 ProfileError::IoError(path, e) => {
471 write!(f, "Error reading {}: {}", path.display(), e)
472 }
473 ProfileError::ParseError(path, line, msg) => {
474 write!(
475 f,
476 "Parse error in {} at line {}: {}",
477 path.display(),
478 line,
479 msg
480 )
481 }
482 ProfileError::ValidationError(msg) => {
483 write!(f, "Validation error: {}", msg)
484 }
485 }
486 }
487}
488
489impl std::error::Error for ProfileError {}
490
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an integer-typed constant directly, bypassing `add_constant`.
    fn int_constant(weight: u32, name: &str, description: &str, value: f64) -> UserConstant {
        UserConstant {
            weight,
            name: name.to_string(),
            description: description.to_string(),
            value,
            num_type: NumType::Integer,
        }
    }

    #[test]
    fn test_parse_user_constant() {
        let mut profile = Profile::new();
        let directive = r#"-X "4:gamma:Euler's constant:0.5772156649""#;
        parse_user_constant(&mut profile, directive).unwrap();

        assert_eq!(profile.constants.len(), 1);
        let gamma = &profile.constants[0];
        assert_eq!(gamma.name, "gamma");
        assert_eq!(gamma.weight, 4);
        assert!((gamma.value - 0.5772156649).abs() < 1e-10);
    }

    #[test]
    fn test_parse_symbol_names() {
        let mut profile = Profile::new();
        parse_symbol_names(&mut profile, "--symbol-names :p:π :e:ℯ").unwrap();

        let pi_name = profile.symbol_names.get(&Symbol::Pi).map(String::as_str);
        let e_name = profile.symbol_names.get(&Symbol::E).map(String::as_str);
        assert_eq!(pi_name, Some("π"));
        assert_eq!(e_name, Some("ℯ"));
    }

    #[test]
    fn test_parse_symbol_weights() {
        let mut profile = Profile::new();
        parse_symbol_weights(&mut profile, "--symbol-weights :W:20 :p:25").unwrap();

        let weights = &profile.symbol_weights;
        assert_eq!(weights.get(&Symbol::LambertW), Some(&20));
        assert_eq!(weights.get(&Symbol::Pi), Some(&25));
    }

    #[test]
    fn test_profile_merge() {
        let mut first = Profile::new();
        first.constants.push(int_constant(4, "a", "First", 1.0));

        let mut second = Profile::new();
        second.constants.push(int_constant(5, "b", "Second", 2.0));
        second.symbol_names.insert(Symbol::Pi, "π".to_string());

        let merged = first.merge(second);

        assert_eq!(merged.constants.len(), 2);
        assert_eq!(merged.symbol_names.len(), 1);
    }
}