Skip to main content

datui_lib/
template.rs

1use color_eyre::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::hash_map::DefaultHasher;
4use std::collections::HashSet;
5use std::fs;
6use std::hash::{Hash, Hasher};
7use std::io::Write;
8use std::path::{Path, PathBuf};
9use std::time::SystemTime;
10
11use polars::prelude::Schema;
12
13use crate::config::ConfigManager;
14use crate::filter_modal::FilterStatement;
15use crate::pivot_melt_modal::{MeltSpec, PivotSpec};
16
17// Custom serialization for SystemTime (convert to/from seconds since epoch)
18mod time_serde {
19    use serde::{Deserialize, Deserializer, Serialize, Serializer};
20    use std::time::{SystemTime, UNIX_EPOCH};
21
22    pub fn serialize<S>(time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
23    where
24        S: Serializer,
25    {
26        let duration = time.duration_since(UNIX_EPOCH).map_err(|e| {
27            serde::ser::Error::custom(format!("Failed to serialize SystemTime: {}", e))
28        })?;
29        duration.as_secs().serialize(serializer)
30    }
31
32    pub fn deserialize<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
33    where
34        D: Deserializer<'de>,
35    {
36        let secs = u64::deserialize(deserializer)?;
37        Ok(UNIX_EPOCH + std::time::Duration::from_secs(secs))
38    }
39
40    pub mod option {
41        use super::*;
42
43        pub fn serialize<S>(time: &Option<SystemTime>, serializer: S) -> Result<S::Ok, S::Error>
44        where
45            S: Serializer,
46        {
47            match time {
48                Some(time) => super::serialize(time, serializer),
49                None => serializer.serialize_none(),
50            }
51        }
52
53        pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<SystemTime>, D::Error>
54        where
55            D: Deserializer<'de>,
56        {
57            Option::<u64>::deserialize(deserializer)?
58                .map(|secs| Ok(UNIX_EPOCH + std::time::Duration::from_secs(secs)))
59                .transpose()
60        }
61    }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct Template {
66    pub id: String,
67    pub name: String,
68    pub description: Option<String>,
69    #[serde(with = "time_serde")]
70    pub created: SystemTime,
71    #[serde(with = "time_serde::option")]
72    #[serde(skip_serializing_if = "Option::is_none")]
73    #[serde(default)]
74    pub last_used: Option<SystemTime>,
75    pub usage_count: usize,
76    #[serde(skip_serializing_if = "Option::is_none")]
77    pub last_matched_file: Option<PathBuf>,
78    pub match_criteria: MatchCriteria,
79    pub settings: TemplateSettings,
80}
81
/// Criteria used to score how well a template fits a given file and schema.
///
/// Every field is optional and is omitted from the serialized form when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MatchCriteria {
    /// Path compared for equality with the file path being scored.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exact_path: Option<PathBuf>,
    /// Path compared with the file path after stripping the current working directory.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub relative_path: Option<String>,
    /// Wildcard pattern matched against the whole file path.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub path_pattern: Option<String>,
    /// Wildcard pattern matched against the file name only.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filename_pattern: Option<String>,
    /// Column names expected in the schema; used for exact and partial schema matching.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub schema_columns: Option<Vec<String>>,
    /// NOTE(review): presumably the column types paired with `schema_columns`;
    /// the relevance calculation in this file does not consult it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub schema_types: Option<Vec<String>>,
}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct TemplateSettings {
100    #[serde(skip_serializing_if = "Option::is_none")]
101    pub query: Option<String>,
102    #[serde(skip_serializing_if = "Option::is_none")]
103    #[serde(default)]
104    pub sql_query: Option<String>,
105    #[serde(skip_serializing_if = "Option::is_none")]
106    #[serde(default)]
107    pub fuzzy_query: Option<String>,
108    pub filters: Vec<FilterStatement>,
109    pub sort_columns: Vec<String>,
110    pub sort_ascending: bool,
111    pub column_order: Vec<String>,
112    pub locked_columns_count: usize,
113    #[serde(skip_serializing_if = "Option::is_none")]
114    #[serde(default)]
115    pub pivot: Option<PivotSpec>,
116    #[serde(skip_serializing_if = "Option::is_none")]
117    #[serde(default)]
118    pub melt: Option<MeltSpec>,
119}
120
/// Holds the in-memory list of templates and persists each one as a JSON file
/// inside `templates_dir`.
pub struct TemplateManager {
    /// Configuration handle; provides the config directory and creates it on save.
    config: ConfigManager,
    /// Templates currently held in memory.
    templates: Vec<Template>,
    /// Directory containing the template files (`<config_dir>/templates`).
    pub(crate) templates_dir: PathBuf,
}
126
127impl TemplateManager {
128    /// Creates a template manager that loads templates from disk. Use `empty()` when
129    /// config dirs are unavailable to avoid panicking on startup.
130    pub fn new(config: &ConfigManager) -> Result<Self> {
131        // Don't create directories on startup - be sensitive to constrained environments
132        // Directories will be created lazily when actually needed (e.g., saving templates)
133        let templates_dir = config.config_dir().join("templates");
134
135        let mut manager = Self {
136            config: config.clone(),
137            templates: Vec::new(),
138            templates_dir,
139        };
140
141        // Only try to load templates if the directory exists
142        // Don't create it if it doesn't exist
143        manager.load_templates()?;
144        Ok(manager)
145    }
146
147    /// Creates an empty in-memory template manager (no disk load). Use when
148    /// `new()` fails so the app can start without panicking; save may fail later.
149    pub fn empty(config: &ConfigManager) -> Self {
150        Self {
151            config: config.clone(),
152            templates: Vec::new(),
153            templates_dir: config.config_dir().join("templates"),
154        }
155    }
156
157    pub fn load_templates(&mut self) -> Result<()> {
158        self.templates.clear();
159
160        // Load all template files
161        if !self.templates_dir.exists() {
162            return Ok(());
163        }
164
165        let entries = fs::read_dir(&self.templates_dir)?;
166        for entry in entries {
167            let entry = entry?;
168            let path = entry.path();
169
170            if path.is_file() && path.extension().and_then(|s| s.to_str()) == Some("json") {
171                if let Ok(content) = fs::read_to_string(&path) {
172                    match serde_json::from_str::<Template>(&content) {
173                        Ok(template) => {
174                            self.templates.push(template);
175                        }
176                        Err(e) => {
177                            eprintln!("Warning: Could not parse template file {:?}: {}", path, e);
178                        }
179                    }
180                }
181            }
182        }
183
184        Ok(())
185    }
186
187    pub fn save_template(&self, template: &Template) -> Result<()> {
188        // Ensure config directory exists first
189        self.config.ensure_config_dir()?;
190
191        // Always ensure templates directory exists before writing
192        // Don't rely on existence checks - always create if needed
193        // This handles cases where the directory might have been deleted
194        // or where tests run in environments with different file system behavior
195        fs::create_dir_all(&self.templates_dir)?;
196
197        let filename = format!("template_{}.json", template.id);
198        let file_path = self.templates_dir.join(filename);
199
200        // Ensure the parent directory exists right before opening the file
201        // Double-check for robustness, especially in CI/test environments
202        if let Some(parent) = file_path.parent() {
203            fs::create_dir_all(parent)?;
204        }
205
206        let json = serde_json::to_string_pretty(template)?;
207
208        // Use file locking to prevent race conditions
209        use fs2::FileExt;
210        let mut file = fs::OpenOptions::new()
211            .create(true)
212            .write(true)
213            .truncate(true)
214            .open(&file_path)?;
215
216        file.lock_exclusive()?;
217        file.write_all(json.as_bytes())?;
218        file.flush()?;
219        file.unlock()?;
220
221        Ok(())
222    }
223
224    pub fn delete_template(&mut self, id: &str) -> Result<()> {
225        let filename = format!("template_{}.json", id);
226        let file_path = self.templates_dir.join(filename);
227
228        if file_path.exists() {
229            fs::remove_file(&file_path)?;
230        }
231
232        self.templates.retain(|t| t.id != id);
233        Ok(())
234    }
235
236    pub fn find_relevant_templates(
237        &self,
238        file_path: &Path,
239        schema: &Schema,
240    ) -> Vec<(Template, f64)> {
241        let mut results: Vec<(Template, f64)> = self
242            .templates
243            .iter()
244            .map(|template| {
245                let score = calculate_relevance(template, file_path, schema);
246                (template.clone(), score)
247            })
248            .collect();
249
250        // Sort by relevance score (highest first)
251        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
252
253        results
254    }
255
256    pub fn get_most_relevant(&self, file_path: &Path, schema: &Schema) -> Option<Template> {
257        self.find_relevant_templates(file_path, schema)
258            .into_iter()
259            .next()
260            .map(|(template, _)| template)
261    }
262
263    pub fn generate_next_template_name(&self) -> String {
264        let mut max_num = 0;
265
266        for template in &self.templates {
267            if template.name.starts_with("template") {
268                if let Some(num_str) = template.name.strip_prefix("template") {
269                    if let Ok(num) = num_str.parse::<u32>() {
270                        max_num = max_num.max(num);
271                    }
272                }
273            }
274        }
275
276        format!("template{:04}", max_num + 1)
277    }
278
279    pub fn template_exists(&self, name: &str) -> bool {
280        self.templates.iter().any(|t| t.name == name)
281    }
282
283    pub fn get_template_by_name(&self, name: &str) -> Option<&Template> {
284        self.templates.iter().find(|t| t.name == name)
285    }
286
287    pub fn get_template_by_id(&self, id: &str) -> Option<&Template> {
288        self.templates.iter().find(|t| t.id == id)
289    }
290
291    pub fn all_templates(&self) -> &[Template] {
292        &self.templates
293    }
294
295    pub fn create_template(
296        &mut self,
297        name: String,
298        description: Option<String>,
299        match_criteria: MatchCriteria,
300        settings: TemplateSettings,
301    ) -> Result<Template> {
302        // Generate unique ID based on name and timestamp
303        let mut hasher = DefaultHasher::new();
304        name.hash(&mut hasher);
305        SystemTime::now()
306            .duration_since(SystemTime::UNIX_EPOCH)
307            .unwrap_or_default()
308            .as_secs()
309            .hash(&mut hasher);
310        let id = format!("{:016x}", hasher.finish());
311
312        let template = Template {
313            id,
314            name,
315            description,
316            created: SystemTime::now(),
317            last_used: None,
318            usage_count: 0,
319            last_matched_file: None,
320            match_criteria,
321            settings,
322        };
323
324        // Save the template
325        self.save_template(&template)?;
326
327        // Reload templates to include the new one
328        self.load_templates()?;
329
330        Ok(template)
331    }
332
333    pub fn update_template(&mut self, template: &Template) -> Result<()> {
334        // Save the updated template
335        self.save_template(template)?;
336
337        // Update in-memory list
338        if let Some(existing) = self.templates.iter_mut().find(|t| t.id == template.id) {
339            *existing = template.clone();
340        } else {
341            // If not found, add it (shouldn't happen, but handle gracefully)
342            self.templates.push(template.clone());
343        }
344
345        Ok(())
346    }
347
348    pub fn remove_all_templates(&mut self) -> Result<()> {
349        // Delete all template files
350        if self.templates_dir.exists() {
351            for entry in fs::read_dir(&self.templates_dir)? {
352                let entry = entry?;
353                let path = entry.path();
354                if path.is_file()
355                    && path
356                        .file_name()
357                        .and_then(|n| n.to_str())
358                        .map(|s| s.starts_with("template_") && s.ends_with(".json"))
359                        .unwrap_or(false)
360                {
361                    fs::remove_file(&path)?;
362                }
363            }
364        }
365
366        // Clear in-memory list
367        self.templates.clear();
368
369        Ok(())
370    }
371}
372
373fn calculate_relevance(template: &Template, file_path: &Path, schema: &Schema) -> f64 {
374    let mut score = 0.0;
375
376    // Check exact path (absolute) match
377    let exact_path_match = template
378        .match_criteria
379        .exact_path
380        .as_ref()
381        .map(|exact| exact == file_path)
382        .unwrap_or(false);
383
384    // Check relative path match
385    let relative_path_match = if let Some(relative_path) = &template.match_criteria.relative_path {
386        // Calculate relative path from current working directory
387        if let Ok(cwd) = std::env::current_dir() {
388            if let Ok(rel_path) = file_path.strip_prefix(&cwd) {
389                rel_path.to_string_lossy() == *relative_path
390            } else {
391                false
392            }
393        } else {
394            false
395        }
396    } else {
397        false
398    };
399
400    // Check for exact schema match
401    let exact_schema_match = if let Some(required_cols) = &template.match_criteria.schema_columns {
402        let file_cols: HashSet<&str> = schema.iter_names().map(|s| s.as_str()).collect();
403        let required_cols_set: HashSet<&str> = required_cols.iter().map(|s| s.as_str()).collect();
404
405        // All required columns present AND no extra columns (exact match)
406        required_cols_set.is_subset(&file_cols) && file_cols.len() == required_cols_set.len()
407    } else {
408        false
409    };
410
411    // Exact path (absolute) + exact schema: highest priority (2000 points)
412    if exact_path_match && exact_schema_match {
413        return 2000.0;
414    }
415
416    // Exact path (absolute) only: very high priority (1000 points)
417    if exact_path_match {
418        return 1000.0;
419    }
420
421    // Relative path + exact schema: very high priority (1950 points)
422    if relative_path_match && exact_schema_match {
423        return 1950.0;
424    }
425
426    // Relative path only: very high priority (950 points)
427    if relative_path_match {
428        return 950.0;
429    }
430
431    // Exact schema only (without path matches): very high priority (900 points)
432    if exact_schema_match {
433        return 900.0;
434    }
435
436    // For non-exact matches, sum components
437    // Path pattern match
438    if let Some(pattern) = &template.match_criteria.path_pattern {
439        if matches_pattern(file_path.to_str().unwrap_or(""), pattern) {
440            score += 50.0;
441            score += pattern_specificity_bonus(pattern);
442        }
443    }
444
445    // Filename pattern match
446    if let Some(pattern) = &template.match_criteria.filename_pattern {
447        if let Some(filename) = file_path.file_name() {
448            if let Some(filename_str) = filename.to_str() {
449                if matches_pattern(filename_str, pattern) {
450                    score += 30.0;
451                    score += pattern_specificity_bonus(pattern);
452                }
453            }
454        }
455    }
456
457    // Partial schema matching (only if not exact match)
458    if let Some(required_cols) = &template.match_criteria.schema_columns {
459        let file_cols: HashSet<&str> = schema.iter_names().map(|s| s.as_str()).collect();
460        let matching_count = required_cols
461            .iter()
462            .filter(|col| file_cols.contains(col.as_str()))
463            .count();
464        score += (matching_count as f64) * 2.0; // 2 points per matching column
465
466        // Optional: type matching bonus (if types are specified)
467        // This would require comparing schema types, which is more complex
468    }
469
470    // Usage statistics
471    score += (template.usage_count.min(10) as f64) * 1.0;
472    if let Some(last_used) = template.last_used {
473        if let Ok(duration) = SystemTime::now().duration_since(last_used) {
474            let days_since = duration.as_secs() / 86400;
475            if days_since <= 7 {
476                score += 5.0;
477            } else if days_since <= 30 {
478                score += 2.0;
479            }
480        }
481    }
482
483    // Age penalty
484    if let Ok(duration) = SystemTime::now().duration_since(template.created) {
485        let months_old = (duration.as_secs() / (30 * 86400)) as f64;
486        score -= months_old * 1.0;
487    }
488
489    score
490}
491
/// Returns a bonus that rewards specific patterns: the fewer wildcard characters
/// (`*` and `?`) a pattern contains, the higher the bonus.
fn pattern_specificity_bonus(pattern: &str) -> f64 {
    let wildcards = pattern
        .chars()
        .filter(|ch| matches!(ch, '*' | '?'))
        .count();

    if wildcards == 0 {
        10.0 // fully literal pattern
    } else if wildcards == 1 {
        5.0
    } else if wildcards == 2 {
        3.0
    } else if wildcards == 3 {
        1.0
    } else {
        0.0 // four or more wildcards earn nothing
    }
}
503
/// Simple glob-like matching of `text` against `pattern`.
///
/// Supported wildcards:
/// * `*` matches any sequence of characters, including the empty one;
/// * `?` matches exactly one character.
///
/// Every other character matches itself literally, and the whole of `text` must
/// be consumed. Wildcards are not path-aware: both also match `/`.
///
/// `?` was documented (and counted as a wildcard when scoring pattern
/// specificity) but was previously compared literally; it is now honoured.
fn matches_pattern(text: &str, pattern: &str) -> bool {
    // Work on characters so `?` consumes one character, not one byte.
    let text: Vec<char> = text.chars().collect();
    let pattern: Vec<char> = pattern.chars().collect();

    let mut t = 0; // next unmatched position in `text`
    let mut p = 0; // next unmatched position in `pattern`
    // Most recent `*` seen: (its pattern index, text index where its span ends).
    let mut last_star: Option<(usize, usize)> = None;

    while t < text.len() {
        match pattern.get(p).copied() {
            Some('*') => {
                // Start by letting the star match nothing; widen on demand below.
                last_star = Some((p, t));
                p += 1;
            }
            Some(ch) if ch == '?' || ch == text[t] => {
                t += 1;
                p += 1;
            }
            _ => match last_star {
                // Mismatch: let the last star absorb one more character and retry.
                Some((star_p, star_t)) => {
                    last_star = Some((star_p, star_t + 1));
                    p = star_p + 1;
                    t = star_t + 1;
                }
                None => return false,
            },
        }
    }

    // Text is exhausted; whatever is left of the pattern must be able to match nothing.
    pattern[p..].iter().all(|&ch| ch == '*')
}
565
#[cfg(test)]
mod tests {
    use super::*;

    /// Settings JSON written before `sql_query` / `fuzzy_query` existed must still
    /// deserialize, with both fields coming back as `None`.
    #[test]
    fn test_settings_deserialize_without_sql_fuzzy() {
        let legacy_json = r#"{
            "query": "select a",
            "filters": [],
            "sort_columns": [],
            "sort_ascending": true,
            "column_order": ["a", "b"],
            "locked_columns_count": 0
        }"#;

        let parsed = serde_json::from_str::<TemplateSettings>(legacy_json).unwrap();

        assert_eq!(parsed.query.as_deref(), Some("select a"));
        assert!(parsed.sql_query.is_none());
        assert!(parsed.fuzzy_query.is_none());
    }

    /// Literal and `*` patterns: (text, pattern, expected result).
    #[test]
    fn test_matches_pattern() {
        let cases = [
            ("test.csv", "test.csv", true),
            ("test.csv", "*.csv", true),
            ("sales_2024.csv", "sales_*.csv", true),
            ("/data/reports/sales.csv", "/data/reports/*.csv", true),
            ("test.txt", "*.csv", false),
            ("sales.csv", "sales_*.csv", false),
        ];

        for &(text, pattern, expected) in cases.iter() {
            assert_eq!(
                matches_pattern(text, pattern),
                expected,
                "text {:?} against pattern {:?}",
                text,
                pattern
            );
        }
    }

    /// Fewer wildcards earn a larger bonus: (pattern, expected bonus).
    #[test]
    fn test_pattern_specificity_bonus() {
        let cases = [("test.csv", 10.0), ("*.csv", 5.0), ("sales_*.csv", 5.0)];

        for &(pattern, expected) in cases.iter() {
            assert_eq!(pattern_specificity_bonus(pattern), expected);
        }
    }
}