1use color_eyre::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::hash_map::DefaultHasher;
4use std::collections::HashSet;
5use std::fs;
6use std::hash::{Hash, Hasher};
7use std::io::Write;
8use std::path::{Path, PathBuf};
9use std::time::SystemTime;
10
11use polars::prelude::Schema;
12
13use crate::config::ConfigManager;
14use crate::filter_modal::FilterStatement;
15use crate::pivot_melt_modal::{MeltSpec, PivotSpec};
16
17mod time_serde {
19 use serde::{Deserialize, Deserializer, Serialize, Serializer};
20 use std::time::{SystemTime, UNIX_EPOCH};
21
22 pub fn serialize<S>(time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
23 where
24 S: Serializer,
25 {
26 let duration = time.duration_since(UNIX_EPOCH).map_err(|e| {
27 serde::ser::Error::custom(format!("Failed to serialize SystemTime: {}", e))
28 })?;
29 duration.as_secs().serialize(serializer)
30 }
31
32 pub fn deserialize<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
33 where
34 D: Deserializer<'de>,
35 {
36 let secs = u64::deserialize(deserializer)?;
37 Ok(UNIX_EPOCH + std::time::Duration::from_secs(secs))
38 }
39
40 pub mod option {
41 use super::*;
42
43 pub fn serialize<S>(time: &Option<SystemTime>, serializer: S) -> Result<S::Ok, S::Error>
44 where
45 S: Serializer,
46 {
47 match time {
48 Some(time) => super::serialize(time, serializer),
49 None => serializer.serialize_none(),
50 }
51 }
52
53 pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<SystemTime>, D::Error>
54 where
55 D: Deserializer<'de>,
56 {
57 Option::<u64>::deserialize(deserializer)?
58 .map(|secs| Ok(UNIX_EPOCH + std::time::Duration::from_secs(secs)))
59 .transpose()
60 }
61 }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct Template {
66 pub id: String,
67 pub name: String,
68 pub description: Option<String>,
69 #[serde(with = "time_serde")]
70 pub created: SystemTime,
71 #[serde(with = "time_serde::option")]
72 #[serde(skip_serializing_if = "Option::is_none")]
73 #[serde(default)]
74 pub last_used: Option<SystemTime>,
75 pub usage_count: usize,
76 #[serde(skip_serializing_if = "Option::is_none")]
77 pub last_matched_file: Option<PathBuf>,
78 pub match_criteria: MatchCriteria,
79 pub settings: TemplateSettings,
80}
81
/// Optional criteria used by relevance scoring to decide how well a template fits a file.
///
/// Every field is optional and is left out of the serialized form when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MatchCriteria {
    /// Compared for equality with the file path being scored.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exact_path: Option<PathBuf>,
    /// Compared with the file path relative to the current working directory.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub relative_path: Option<String>,
    /// Wildcard pattern tested against the whole file path string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub path_pattern: Option<String>,
    /// Wildcard pattern tested against the file name only.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filename_pattern: Option<String>,
    /// Column names compared with the names in the file's schema.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub schema_columns: Option<Vec<String>>,
    /// NOTE(review): not read by the scoring code in this file; presumably the column types
    /// that accompany `schema_columns` — confirm against the code that fills it in.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub schema_types: Option<Vec<String>>,
}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct TemplateSettings {
100 #[serde(skip_serializing_if = "Option::is_none")]
101 pub query: Option<String>,
102 #[serde(skip_serializing_if = "Option::is_none")]
103 #[serde(default)]
104 pub sql_query: Option<String>,
105 #[serde(skip_serializing_if = "Option::is_none")]
106 #[serde(default)]
107 pub fuzzy_query: Option<String>,
108 pub filters: Vec<FilterStatement>,
109 pub sort_columns: Vec<String>,
110 pub sort_ascending: bool,
111 pub column_order: Vec<String>,
112 pub locked_columns_count: usize,
113 #[serde(skip_serializing_if = "Option::is_none")]
114 #[serde(default)]
115 pub pivot: Option<PivotSpec>,
116 #[serde(skip_serializing_if = "Option::is_none")]
117 #[serde(default)]
118 pub melt: Option<MeltSpec>,
119}
120
/// Keeps the in-memory list of templates and reads/writes them as JSON files on disk.
pub struct TemplateManager {
    /// Clone of the configuration manager this instance was built from.
    config: ConfigManager,
    /// Templates currently held in memory.
    templates: Vec<Template>,
    /// Directory holding the template files (`<config_dir>/templates`).
    pub(crate) templates_dir: PathBuf,
}
126
127impl TemplateManager {
128 pub fn new(config: &ConfigManager) -> Result<Self> {
131 let templates_dir = config.config_dir().join("templates");
134
135 let mut manager = Self {
136 config: config.clone(),
137 templates: Vec::new(),
138 templates_dir,
139 };
140
141 manager.load_templates()?;
144 Ok(manager)
145 }
146
147 pub fn empty(config: &ConfigManager) -> Self {
150 Self {
151 config: config.clone(),
152 templates: Vec::new(),
153 templates_dir: config.config_dir().join("templates"),
154 }
155 }
156
157 pub fn load_templates(&mut self) -> Result<()> {
158 self.templates.clear();
159
160 if !self.templates_dir.exists() {
162 return Ok(());
163 }
164
165 let entries = fs::read_dir(&self.templates_dir)?;
166 for entry in entries {
167 let entry = entry?;
168 let path = entry.path();
169
170 if path.is_file() && path.extension().and_then(|s| s.to_str()) == Some("json") {
171 if let Ok(content) = fs::read_to_string(&path) {
172 match serde_json::from_str::<Template>(&content) {
173 Ok(template) => {
174 self.templates.push(template);
175 }
176 Err(e) => {
177 eprintln!("Warning: Could not parse template file {:?}: {}", path, e);
178 }
179 }
180 }
181 }
182 }
183
184 Ok(())
185 }
186
187 pub fn save_template(&self, template: &Template) -> Result<()> {
188 self.config.ensure_config_dir()?;
190
191 fs::create_dir_all(&self.templates_dir)?;
196
197 let filename = format!("template_{}.json", template.id);
198 let file_path = self.templates_dir.join(filename);
199
200 if let Some(parent) = file_path.parent() {
203 fs::create_dir_all(parent)?;
204 }
205
206 let json = serde_json::to_string_pretty(template)?;
207
208 use fs2::FileExt;
210 let mut file = fs::OpenOptions::new()
211 .create(true)
212 .write(true)
213 .truncate(true)
214 .open(&file_path)?;
215
216 file.lock_exclusive()?;
217 file.write_all(json.as_bytes())?;
218 file.flush()?;
219 file.unlock()?;
220
221 Ok(())
222 }
223
224 pub fn delete_template(&mut self, id: &str) -> Result<()> {
225 let filename = format!("template_{}.json", id);
226 let file_path = self.templates_dir.join(filename);
227
228 if file_path.exists() {
229 fs::remove_file(&file_path)?;
230 }
231
232 self.templates.retain(|t| t.id != id);
233 Ok(())
234 }
235
236 pub fn find_relevant_templates(
237 &self,
238 file_path: &Path,
239 schema: &Schema,
240 ) -> Vec<(Template, f64)> {
241 let mut results: Vec<(Template, f64)> = self
242 .templates
243 .iter()
244 .map(|template| {
245 let score = calculate_relevance(template, file_path, schema);
246 (template.clone(), score)
247 })
248 .collect();
249
250 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
252
253 results
254 }
255
256 pub fn get_most_relevant(&self, file_path: &Path, schema: &Schema) -> Option<Template> {
257 self.find_relevant_templates(file_path, schema)
258 .into_iter()
259 .next()
260 .map(|(template, _)| template)
261 }
262
263 pub fn generate_next_template_name(&self) -> String {
264 let mut max_num = 0;
265
266 for template in &self.templates {
267 if template.name.starts_with("template") {
268 if let Some(num_str) = template.name.strip_prefix("template") {
269 if let Ok(num) = num_str.parse::<u32>() {
270 max_num = max_num.max(num);
271 }
272 }
273 }
274 }
275
276 format!("template{:04}", max_num + 1)
277 }
278
279 pub fn template_exists(&self, name: &str) -> bool {
280 self.templates.iter().any(|t| t.name == name)
281 }
282
283 pub fn get_template_by_name(&self, name: &str) -> Option<&Template> {
284 self.templates.iter().find(|t| t.name == name)
285 }
286
287 pub fn get_template_by_id(&self, id: &str) -> Option<&Template> {
288 self.templates.iter().find(|t| t.id == id)
289 }
290
291 pub fn all_templates(&self) -> &[Template] {
292 &self.templates
293 }
294
295 pub fn create_template(
296 &mut self,
297 name: String,
298 description: Option<String>,
299 match_criteria: MatchCriteria,
300 settings: TemplateSettings,
301 ) -> Result<Template> {
302 let mut hasher = DefaultHasher::new();
304 name.hash(&mut hasher);
305 SystemTime::now()
306 .duration_since(SystemTime::UNIX_EPOCH)
307 .unwrap_or_default()
308 .as_secs()
309 .hash(&mut hasher);
310 let id = format!("{:016x}", hasher.finish());
311
312 let template = Template {
313 id,
314 name,
315 description,
316 created: SystemTime::now(),
317 last_used: None,
318 usage_count: 0,
319 last_matched_file: None,
320 match_criteria,
321 settings,
322 };
323
324 self.save_template(&template)?;
326
327 self.load_templates()?;
329
330 Ok(template)
331 }
332
333 pub fn update_template(&mut self, template: &Template) -> Result<()> {
334 self.save_template(template)?;
336
337 if let Some(existing) = self.templates.iter_mut().find(|t| t.id == template.id) {
339 *existing = template.clone();
340 } else {
341 self.templates.push(template.clone());
343 }
344
345 Ok(())
346 }
347
348 pub fn remove_all_templates(&mut self) -> Result<()> {
349 if self.templates_dir.exists() {
351 for entry in fs::read_dir(&self.templates_dir)? {
352 let entry = entry?;
353 let path = entry.path();
354 if path.is_file()
355 && path
356 .file_name()
357 .and_then(|n| n.to_str())
358 .map(|s| s.starts_with("template_") && s.ends_with(".json"))
359 .unwrap_or(false)
360 {
361 fs::remove_file(&path)?;
362 }
363 }
364 }
365
366 self.templates.clear();
368
369 Ok(())
370 }
371}
372
373fn calculate_relevance(template: &Template, file_path: &Path, schema: &Schema) -> f64 {
374 let mut score = 0.0;
375
376 let exact_path_match = template
378 .match_criteria
379 .exact_path
380 .as_ref()
381 .map(|exact| exact == file_path)
382 .unwrap_or(false);
383
384 let relative_path_match = if let Some(relative_path) = &template.match_criteria.relative_path {
386 if let Ok(cwd) = std::env::current_dir() {
388 if let Ok(rel_path) = file_path.strip_prefix(&cwd) {
389 rel_path.to_string_lossy() == *relative_path
390 } else {
391 false
392 }
393 } else {
394 false
395 }
396 } else {
397 false
398 };
399
400 let exact_schema_match = if let Some(required_cols) = &template.match_criteria.schema_columns {
402 let file_cols: HashSet<&str> = schema.iter_names().map(|s| s.as_str()).collect();
403 let required_cols_set: HashSet<&str> = required_cols.iter().map(|s| s.as_str()).collect();
404
405 required_cols_set.is_subset(&file_cols) && file_cols.len() == required_cols_set.len()
407 } else {
408 false
409 };
410
411 if exact_path_match && exact_schema_match {
413 return 2000.0;
414 }
415
416 if exact_path_match {
418 return 1000.0;
419 }
420
421 if relative_path_match && exact_schema_match {
423 return 1950.0;
424 }
425
426 if relative_path_match {
428 return 950.0;
429 }
430
431 if exact_schema_match {
433 return 900.0;
434 }
435
436 if let Some(pattern) = &template.match_criteria.path_pattern {
439 if matches_pattern(file_path.to_str().unwrap_or(""), pattern) {
440 score += 50.0;
441 score += pattern_specificity_bonus(pattern);
442 }
443 }
444
445 if let Some(pattern) = &template.match_criteria.filename_pattern {
447 if let Some(filename) = file_path.file_name() {
448 if let Some(filename_str) = filename.to_str() {
449 if matches_pattern(filename_str, pattern) {
450 score += 30.0;
451 score += pattern_specificity_bonus(pattern);
452 }
453 }
454 }
455 }
456
457 if let Some(required_cols) = &template.match_criteria.schema_columns {
459 let file_cols: HashSet<&str> = schema.iter_names().map(|s| s.as_str()).collect();
460 let matching_count = required_cols
461 .iter()
462 .filter(|col| file_cols.contains(col.as_str()))
463 .count();
464 score += (matching_count as f64) * 2.0; }
469
470 score += (template.usage_count.min(10) as f64) * 1.0;
472 if let Some(last_used) = template.last_used {
473 if let Ok(duration) = SystemTime::now().duration_since(last_used) {
474 let days_since = duration.as_secs() / 86400;
475 if days_since <= 7 {
476 score += 5.0;
477 } else if days_since <= 30 {
478 score += 2.0;
479 }
480 }
481 }
482
483 if let Ok(duration) = SystemTime::now().duration_since(template.created) {
485 let months_old = (duration.as_secs() / (30 * 86400)) as f64;
486 score -= months_old * 1.0;
487 }
488
489 score
490}
491
/// Returns a bonus that rewards patterns with fewer wildcard characters (`*` or `?`):
/// 10 for none, then 5, 3, 1, and 0 for four or more.
fn pattern_specificity_bonus(pattern: &str) -> f64 {
    let wildcards = pattern.chars().filter(|&c| c == '*' || c == '?').count();

    if wildcards == 0 {
        10.0
    } else if wildcards == 1 {
        5.0
    } else if wildcards == 2 {
        3.0
    } else if wildcards == 3 {
        1.0
    } else {
        0.0
    }
}
503
/// Tests `text` against a glob-style `pattern`.
///
/// `*` matches any run of characters (including none, and including `/`), `?` matches exactly
/// one character, and every other character matches itself. The whole of `text` must match.
fn matches_pattern(text: &str, pattern: &str) -> bool {
    // Work on chars so `?` consumes one character, not one byte.
    let text: Vec<char> = text.chars().collect();
    let pattern: Vec<char> = pattern.chars().collect();

    let mut t = 0;
    let mut p = 0;
    // Index of the most recent `*` in the pattern and the text index it currently extends to.
    let mut star: Option<(usize, usize)> = None;

    while t < text.len() {
        match pattern.get(p).copied() {
            Some('*') => {
                // Start by letting the star match nothing; widen it on later mismatches.
                star = Some((p, t));
                p += 1;
            }
            Some(c) if c == '?' || c == text[t] => {
                p += 1;
                t += 1;
            }
            _ => match star {
                // Mismatch: let the last star absorb one more character and retry after it.
                Some((star_p, star_t)) => {
                    p = star_p + 1;
                    t = star_t + 1;
                    star = Some((star_p, star_t + 1));
                }
                None => return false,
            },
        }
    }

    // Text is exhausted; whatever remains of the pattern may only be stars.
    pattern[p..].iter().all(|&c| c == '*')
}
565
#[cfg(test)]
mod tests {
    use super::*;

    /// Settings JSON without `sql_query`/`fuzzy_query` still deserializes, leaving both `None`.
    #[test]
    fn test_settings_deserialize_without_sql_fuzzy() {
        let json = r#"{
            "query": "select a",
            "filters": [],
            "sort_columns": [],
            "sort_ascending": true,
            "column_order": ["a", "b"],
            "locked_columns_count": 0
        }"#;
        let settings = serde_json::from_str::<TemplateSettings>(json).unwrap();
        assert_eq!(settings.query.as_deref(), Some("select a"));
        assert!(settings.sql_query.is_none());
        assert!(settings.fuzzy_query.is_none());
    }

    /// Literal and `*` patterns accept and reject the expected inputs.
    #[test]
    fn test_matches_pattern() {
        let accepted = [
            ("test.csv", "test.csv"),
            ("test.csv", "*.csv"),
            ("sales_2024.csv", "sales_*.csv"),
            ("/data/reports/sales.csv", "/data/reports/*.csv"),
        ];
        for &(text, pattern) in accepted.iter() {
            assert!(
                matches_pattern(text, pattern),
                "{:?} should match {:?}",
                text,
                pattern
            );
        }

        let rejected = [("test.txt", "*.csv"), ("sales.csv", "sales_*.csv")];
        for &(text, pattern) in rejected.iter() {
            assert!(
                !matches_pattern(text, pattern),
                "{:?} should not match {:?}",
                text,
                pattern
            );
        }
    }

    /// The bonus depends only on the number of wildcard characters.
    #[test]
    fn test_pattern_specificity_bonus() {
        let expectations = [("test.csv", 10.0), ("*.csv", 5.0), ("sales_*.csv", 5.0)];
        for &(pattern, bonus) in expectations.iter() {
            assert_eq!(pattern_specificity_bonus(pattern), bonus);
        }
    }
}
606}