Skip to main content

ries_rs/
profile.rs

1//! Profile file support for RIES configuration
2//!
3//! Parse and load `.ries` profile files for custom configuration including
4//! user-defined constants, user-defined functions, symbol names, and symbol weights.
5
6use std::collections::HashMap;
7use std::fs;
8use std::io::{self, BufRead};
9use std::path::{Path, PathBuf};
10
11use crate::symbol::{NumType, Symbol};
12
13// Re-export UserFunction for convenience
14pub use crate::udf::UserFunction;
15
/// A user-defined constant, as declared by a profile line of the form
/// `-X "weight:name:description:value"`.
#[derive(Clone, Debug)]
pub struct UserConstant {
    /// Weight (complexity) of this constant
    ///
    /// This field is part of the public API and is used when generating expressions
    /// that include user-defined constants. `dead_code` is allowed because nothing
    /// in this file reads it outside the tests.
    #[allow(dead_code)]
    pub weight: u32,
    /// Short name (single character preferred; must be non-empty when added
    /// through `Profile::add_constant`)
    pub name: String,
    /// Description (for display)
    ///
    /// This field is part of the public API for documentation and display purposes;
    /// `dead_code` is allowed because this file never reads it.
    #[allow(dead_code)]
    pub description: String,
    /// Numeric value (finite when added through `Profile::add_constant`)
    pub value: f64,
    /// Numeric type classification (integer / rational / transcendental);
    /// `Profile::add_constant` derives it from `value`
    pub num_type: NumType,
}
37
/// Parsed profile configuration.
///
/// Built by [`Profile::from_file`] or [`Profile::load_default`]. Two profiles
/// are combined with [`Profile::merge`], where the second one takes precedence.
#[derive(Clone, Debug, Default)]
pub struct Profile {
    /// User-defined constants (from `-X` lines)
    pub constants: Vec<UserConstant>,
    /// User-defined functions (from `--define` lines)
    pub functions: Vec<UserFunction>,
    /// Custom symbol names (from `--symbol-names`, e.g. `:p:π`)
    pub symbol_names: HashMap<Symbol, String>,
    /// Custom symbol weights (from `--symbol-weights`, e.g. `:W:20`)
    pub symbol_weights: HashMap<Symbol, u32>,
    /// Resolved paths of profile files pulled in via `--include`
    pub includes: Vec<PathBuf>,
}
52
53impl Profile {
54    /// Create an empty profile
55    pub fn new() -> Self {
56        Self::default()
57    }
58
59    /// Load a profile from a file
60    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, ProfileError> {
61        load_profile_recursive(path.as_ref(), &mut Vec::new(), 0)
62    }
63
64    /// Load the default profile chain (~/.ries_profile, ./.ries)
65    pub fn load_default() -> Self {
66        let mut profile = Profile::new();
67
68        // Try to load from home directory
69        if let Some(home) = dirs::home_dir() {
70            let home_profile = home.join(".ries_profile");
71            if home_profile.exists() {
72                if let Ok(p) = Self::from_file(&home_profile) {
73                    profile = profile.merge(p);
74                }
75            }
76        }
77
78        // Try to load from current directory
79        let local_profile = PathBuf::from(".ries");
80        if local_profile.exists() {
81            if let Ok(p) = Self::from_file(&local_profile) {
82                profile = profile.merge(p);
83            }
84        }
85
86        profile
87    }
88
89    /// Add a validated user constant to this profile.
90    ///
91    /// This method centralizes validation logic to ensure consistent
92    /// handling of user constants across CLI and profile file parsing.
93    ///
94    /// # Arguments
95    ///
96    /// * `weight` - Complexity weight for this constant
97    /// * `name` - Short name (single character preferred)
98    /// * `description` - Human-readable description
99    /// * `value` - Numeric value
100    ///
101    /// # Errors
102    ///
103    /// Returns an error if:
104    /// * `name` is empty
105    /// * `value` is not finite (NaN or infinity)
106    pub fn add_constant(
107        &mut self,
108        weight: u32,
109        name: String,
110        description: String,
111        value: f64,
112    ) -> Result<(), ProfileError> {
113        // Validate name
114        if name.is_empty() {
115            return Err(ProfileError::ValidationError(
116                "Constant name cannot be empty".to_string(),
117            ));
118        }
119
120        // Validate value is finite
121        if !value.is_finite() {
122            return Err(ProfileError::ValidationError(format!(
123                "Constant value must be finite (got {})",
124                value
125            )));
126        }
127
128        // Determine numeric type based on value characteristics
129        let num_type = if value.fract() == 0.0 && value.abs() < 1e10 {
130            NumType::Integer
131        } else if is_rational(value) {
132            NumType::Rational
133        } else {
134            NumType::Transcendental
135        };
136
137        self.constants.push(UserConstant {
138            weight,
139            name,
140            description,
141            value,
142            num_type,
143        });
144
145        Ok(())
146    }
147
148    /// Merge another profile into this one (other takes precedence)
149    pub fn merge(mut self, other: Profile) -> Self {
150        // Merge constants (append, later ones override by name)
151        for c in other.constants {
152            // Remove existing constant with same name
153            self.constants.retain(|existing| existing.name != c.name);
154            self.constants.push(c);
155        }
156
157        // Merge functions (append, later ones override by name)
158        for f in other.functions {
159            // Remove existing function with same name
160            self.functions.retain(|existing| existing.name != f.name);
161            self.functions.push(f);
162        }
163
164        // Merge symbol names
165        self.symbol_names.extend(other.symbol_names);
166
167        // Merge symbol weights
168        self.symbol_weights.extend(other.symbol_weights);
169
170        // Merge includes
171        self.includes.extend(other.includes);
172
173        self
174    }
175}
176
177const MAX_INCLUDE_DEPTH: usize = 25;
178
179fn load_profile_recursive(
180    path: &Path,
181    include_stack: &mut Vec<PathBuf>,
182    depth: usize,
183) -> Result<Profile, ProfileError> {
184    if depth > MAX_INCLUDE_DEPTH {
185        return Err(ProfileError::ParseError(
186            path.to_path_buf(),
187            0,
188            format!("Profile include depth exceeded {}", MAX_INCLUDE_DEPTH),
189        ));
190    }
191
192    let canonical = fs::canonicalize(path).unwrap_or_else(|_| path.to_path_buf());
193    if include_stack.contains(&canonical) {
194        return Err(ProfileError::ParseError(
195            path.to_path_buf(),
196            0,
197            "Recursive --include detected".to_string(),
198        ));
199    }
200    include_stack.push(canonical);
201
202    let file = fs::File::open(path).map_err(|e| ProfileError::IoError(path.to_path_buf(), e))?;
203    let mut profile = Profile::new();
204    let reader = io::BufReader::new(file);
205
206    for (line_num, line_result) in reader.lines().enumerate() {
207        let line = line_result.map_err(|e| ProfileError::IoError(path.to_path_buf(), e))?;
208        let trimmed = line.trim();
209        if trimmed.is_empty() || trimmed.starts_with('#') {
210            continue;
211        }
212
213        if trimmed.starts_with("--include") {
214            let include_raw = parse_include_path(trimmed)
215                .map_err(|e| ProfileError::ParseError(path.to_path_buf(), line_num + 1, e))?;
216
217            let include_resolved = resolve_include_path(path, &include_raw).ok_or_else(|| {
218                ProfileError::ParseError(
219                    path.to_path_buf(),
220                    line_num + 1,
221                    format!(
222                        "Could not open '{}' or '{}.ries' for reading",
223                        include_raw.display(),
224                        include_raw.display()
225                    ),
226                )
227            })?;
228
229            profile.includes.push(include_resolved.clone());
230            let nested = load_profile_recursive(&include_resolved, include_stack, depth + 1)?;
231            profile = profile.merge(nested);
232            continue;
233        }
234
235        if let Err(e) = parse_profile_line(&mut profile, trimmed) {
236            return Err(ProfileError::ParseError(
237                path.to_path_buf(),
238                line_num + 1,
239                e,
240            ));
241        }
242    }
243
244    include_stack.pop();
245    Ok(profile)
246}
247
/// Resolve an `--include` target to an existing profile path.
///
/// A relative `include_path` is resolved against the directory containing
/// `current_file` (or `.` if it has none). An absolute one is used as-is. The
/// exact name is tried first, then the name with `.ries` appended.
///
/// Directories are skipped because they can never be read as a profile. So
/// `--include foo` finds `foo.ries` even when a directory named `foo` sits
/// next to it.
///
/// Returns `None` if neither candidate exists as a non-directory.
fn resolve_include_path(current_file: &Path, include_path: &Path) -> Option<PathBuf> {
    /// True if `candidate` exists (following symlinks) and is not a directory.
    fn is_profile_candidate(candidate: &Path) -> bool {
        fs::metadata(candidate)
            .map(|meta| !meta.is_dir())
            .unwrap_or(false)
    }

    let base = current_file.parent().unwrap_or_else(|| Path::new("."));

    // `Path::join` returns its argument unchanged when that argument is
    // absolute, so one expression covers both relative and absolute includes.
    let exact = base.join(include_path);
    if is_profile_candidate(&exact) {
        return Some(exact);
    }

    let mut suffixed_name = include_path.as_os_str().to_os_string();
    suffixed_name.push(".ries");
    let suffixed = base.join(suffixed_name);
    if is_profile_candidate(&suffixed) {
        Some(suffixed)
    } else {
        None
    }
}
269
270/// Parse a single profile line
271fn parse_profile_line(profile: &mut Profile, line: &str) -> Result<(), String> {
272    // Handle -X (user constant) lines
273    if line.starts_with("-X") {
274        return parse_user_constant(profile, line);
275    }
276
277    // Handle --define (user function) lines
278    if line.starts_with("--define") {
279        return parse_user_function(profile, line);
280    }
281
282    // Handle --symbol-names
283    if line.starts_with("--symbol-names") {
284        return parse_symbol_names(profile, line);
285    }
286
287    // Handle --symbol-weights
288    if line.starts_with("--symbol-weights") {
289        return parse_symbol_weights(profile, line);
290    }
291
292    // Unknown directive - could be a comment or unsupported option
293    // For now, just ignore silently
294    Ok(())
295}
296
297/// Parse a user constant definition
298/// Format: -X "weight:name:description:value"
299fn parse_user_constant(profile: &mut Profile, line: &str) -> Result<(), String> {
300    // Extract the quoted part
301    let rest = line[2..].trim();
302
303    // Handle both quoted and unquoted formats
304    let content = if let Some(stripped) = rest.strip_prefix('"') {
305        // Quoted format: -X "weight:name:description:value"
306        let end_quote = stripped.find('"').ok_or("Unclosed quote in -X directive")?;
307        &stripped[..end_quote]
308    } else {
309        // Unquoted format: -X weight:name:description:value
310        rest
311    };
312
313    let parts: Vec<&str> = content.split(':').collect();
314    if parts.len() != 4 {
315        return Err(format!(
316            "Invalid -X format: expected 4 colon-separated parts, got {}",
317            parts.len()
318        ));
319    }
320
321    let weight: u32 = parts[0]
322        .parse()
323        .map_err(|_| format!("Invalid weight: {}", parts[0]))?;
324
325    let name = parts[1].to_string();
326    let description = parts[2].to_string();
327
328    let value: f64 = parts[3]
329        .parse()
330        .map_err(|_| format!("Invalid value: {}", parts[3]))?;
331
332    // Use Profile's centralized validation
333    profile
334        .add_constant(weight, name, description, value)
335        .map_err(|e| e.to_string())?;
336
337    Ok(())
338}
339
/// Check whether `v` is (numerically) a simple fraction `p/q` with `1 <= q <= 100`.
///
/// Zero counts as rational. Non-finite values (NaN, ±∞) are not rational.
///
/// A denominator `q` matches when `v * q` lies within `1e-10` of a *non-zero*
/// integer. Without the non-zero requirement, tiny magnitudes such as `1e-12`
/// would round to `0` and be misclassified as rational.
fn is_rational(v: f64) -> bool {
    if !v.is_finite() {
        return false;
    }
    if v == 0.0 {
        return true;
    }

    (1..=100_u32).any(|denom| {
        let numer = v * f64::from(denom);
        let nearest = numer.round();
        nearest != 0.0 && (nearest - numer).abs() < 1e-10
    })
}
355
356/// Parse a user function definition
357/// Format: --define "weight:name:description:formula"
358fn parse_user_function(profile: &mut Profile, line: &str) -> Result<(), String> {
359    // Extract the quoted part
360    let rest = line["--define".len()..].trim();
361
362    // Handle both quoted and unquoted formats
363    let content = if let Some(stripped) = rest.strip_prefix('"') {
364        // Quoted format: --define "weight:name:description:formula"
365        let end_quote = stripped
366            .find('"')
367            .ok_or("Unclosed quote in --define directive")?;
368        &stripped[..end_quote]
369    } else {
370        // Unquoted format: --define weight:name:description:formula
371        rest
372    };
373
374    // Parse the function using UserFunction::parse
375    let udf = UserFunction::parse(content)?;
376    profile.functions.push(udf);
377
378    Ok(())
379}
380
381/// Parse symbol names directive
382/// Format: --symbol-names :p:π :e:ℯ :f:φ
383fn parse_symbol_names(profile: &mut Profile, line: &str) -> Result<(), String> {
384    let rest = line["--symbol-names".len()..].trim();
385
386    for part in rest.split_whitespace() {
387        if !part.starts_with(':') {
388            continue;
389        }
390
391        let inner = &part[1..];
392        if let Some(colon_pos) = inner.find(':') {
393            let symbol_char = inner[..colon_pos]
394                .chars()
395                .next()
396                .ok_or("Empty symbol in --symbol-names")?;
397            let name = inner[colon_pos + 1..].to_string();
398
399            if let Some(symbol) = Symbol::from_byte(symbol_char as u8) {
400                profile.symbol_names.insert(symbol, name);
401            }
402        }
403    }
404
405    Ok(())
406}
407
408/// Parse symbol weights directive
409/// Format: --symbol-weights :W:20 :p:25
410fn parse_symbol_weights(profile: &mut Profile, line: &str) -> Result<(), String> {
411    let rest = line["--symbol-weights".len()..].trim();
412
413    for part in rest.split_whitespace() {
414        if !part.starts_with(':') {
415            continue;
416        }
417
418        let inner = &part[1..];
419        if let Some(colon_pos) = inner.find(':') {
420            let symbol_char = inner[..colon_pos]
421                .chars()
422                .next()
423                .ok_or("Empty symbol in --symbol-weights")?;
424            let weight: u32 = inner[colon_pos + 1..]
425                .parse()
426                .map_err(|_| format!("Invalid weight in --symbol-weights: {}", inner))?;
427
428            if let Some(symbol) = Symbol::from_byte(symbol_char as u8) {
429                profile.symbol_weights.insert(symbol, weight);
430            }
431        }
432    }
433
434    Ok(())
435}
436
/// Parse the filename argument of an `--include` directive.
///
/// Format: `--include /path/to/profile.ries` or `--include "path with spaces.ries"`.
///
/// # Errors
///
/// Returns an error message in these cases:
/// * no filename is given, including an empty quoted name `""`;
/// * an opening quote has no matching closing quote at the end.
fn parse_include_path(line: &str) -> Result<PathBuf, String> {
    // Callers only pass lines that start with `--include`.
    let rest = line.strip_prefix("--include").unwrap_or(line).trim();

    let path_str = match rest.strip_prefix('"') {
        // The closing quote must be a separate character at the end. A lone `"`
        // is not both the opening and the closing quote; slicing it as
        // `rest[1..len-1]` would panic.
        Some(quoted) => quoted
            .strip_suffix('"')
            .ok_or_else(|| "Unclosed quote in --include directive".to_string())?,
        None => rest,
    };

    if path_str.is_empty() {
        return Err("--include requires a filename".to_string());
    }

    Ok(PathBuf::from(path_str))
}
455
/// Errors that can occur while loading or validating a profile.
#[derive(Debug)]
pub enum ProfileError {
    /// I/O error while reading the given file
    IoError(PathBuf, io::Error),
    /// Parse error in the given file at the given line number, with a message
    ParseError(PathBuf, usize, String),
    /// Validation error (e.g., invalid constant value)
    ValidationError(String),
}

impl std::fmt::Display for ProfileError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::IoError(ref file, ref source) => {
                write!(f, "Error reading {}: {}", file.display(), source)
            }
            Self::ParseError(ref file, line_no, ref message) => write!(
                f,
                "Parse error in {} at line {}: {}",
                file.display(),
                line_no,
                message
            ),
            Self::ValidationError(ref message) => write!(f, "Validation error: {}", message),
        }
    }
}

impl std::error::Error for ProfileError {}
490
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an integer-typed `UserConstant` directly, bypassing validation.
    fn integer_constant(weight: u32, name: &str, description: &str, value: f64) -> UserConstant {
        UserConstant {
            weight,
            name: name.to_string(),
            description: description.to_string(),
            value,
            num_type: NumType::Integer,
        }
    }

    #[test]
    fn test_parse_user_constant() {
        let mut profile = Profile::new();
        let directive = r#"-X "4:gamma:Euler's constant:0.5772156649""#;
        parse_user_constant(&mut profile, directive).unwrap();

        assert_eq!(profile.constants.len(), 1);
        let gamma = &profile.constants[0];
        assert_eq!(gamma.name, "gamma");
        assert_eq!(gamma.weight, 4);
        assert!((gamma.value - 0.5772156649).abs() < 1e-10);
    }

    #[test]
    fn test_parse_symbol_names() {
        let mut profile = Profile::new();
        parse_symbol_names(&mut profile, "--symbol-names :p:π :e:ℯ").unwrap();

        let names = &profile.symbol_names;
        assert_eq!(names.get(&Symbol::Pi).map(String::as_str), Some("π"));
        assert_eq!(names.get(&Symbol::E).map(String::as_str), Some("ℯ"));
    }

    #[test]
    fn test_parse_symbol_weights() {
        let mut profile = Profile::new();
        parse_symbol_weights(&mut profile, "--symbol-weights :W:20 :p:25").unwrap();

        let weights = &profile.symbol_weights;
        assert_eq!(weights.get(&Symbol::LambertW).copied(), Some(20));
        assert_eq!(weights.get(&Symbol::Pi).copied(), Some(25));
    }

    #[test]
    fn test_profile_merge() {
        let mut first = Profile::new();
        first.constants.push(integer_constant(4, "a", "First", 1.0));

        let mut second = Profile::new();
        second.constants.push(integer_constant(5, "b", "Second", 2.0));
        second.symbol_names.insert(Symbol::Pi, "π".to_string());

        let merged = first.merge(second);

        assert_eq!(merged.constants.len(), 2);
        assert_eq!(merged.symbol_names.len(), 1);
    }
}