1use crate::ssh::SshConfig;
16use anyhow::{Context, Result};
17use lru::LruCache;
18use std::collections::HashMap;
19use std::path::{Path, PathBuf};
20use std::sync::{Arc, RwLock};
21use std::time::{Duration, Instant, SystemTime};
22use tracing::{debug, trace};
23
/// Tuning knobs for the SSH config cache.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Upper bound on the number of parsed configs held at once; least recently
    /// used entries are displaced first once the bound is reached.
    pub max_entries: usize,
    /// How long a cached entry may be served before it has to be reloaded.
    pub ttl: Duration,
    /// When `false`, lookups bypass the cache and read the file every time.
    pub enabled: bool,
}

impl Default for CacheConfig {
    /// 100 entries, a five-minute TTL, caching switched on.
    fn default() -> Self {
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);

        Self {
            enabled: true,
            max_entries: 100,
            ttl: FIVE_MINUTES,
        }
    }
}
44
45#[derive(Debug, Clone)]
47struct CacheEntry {
48 config: SshConfig,
50 cached_at: Instant,
52 file_mtime: SystemTime,
54 access_count: u64,
56 last_accessed: Instant,
58}
59
60impl CacheEntry {
61 fn new(config: SshConfig, file_mtime: SystemTime) -> Self {
62 let now = Instant::now();
63 Self {
64 config,
65 cached_at: now,
66 file_mtime,
67 access_count: 0,
68 last_accessed: now,
69 }
70 }
71
72 fn is_expired(&self, ttl: Duration) -> bool {
73 self.cached_at.elapsed() > ttl
74 }
75
76 fn is_stale(&self, current_mtime: SystemTime) -> bool {
77 self.file_mtime != current_mtime
78 }
79
80 fn access(&mut self) -> &SshConfig {
81 self.access_count += 1;
82 self.last_accessed = Instant::now();
83 &self.config
84 }
85}
86
/// Snapshot of the cache's usage counters.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups answered from a valid cached entry.
    pub hits: u64,
    /// Lookups that required loading the file from disk.
    pub misses: u64,
    /// Entries dropped because they outlived the configured TTL.
    pub ttl_evictions: u64,
    /// Entries dropped because the file changed (or could no longer be stat'ed during maintenance).
    pub stale_evictions: u64,
    /// Entries displaced by the LRU policy when inserting into a full cache.
    pub lru_evictions: u64,
    /// Number of entries held by the cache as of the last stats update.
    pub current_entries: usize,
    /// Maximum number of entries the cache is configured to hold.
    pub max_entries: usize,
}

impl CacheStats {
    /// Fraction of lookups served from the cache, in `[0.0, 1.0]`.
    /// Returns `0.0` when no lookups have been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            total => self.hits as f64 / total as f64,
        }
    }

    /// Complement of [`CacheStats::hit_rate`]; consequently `1.0` when no
    /// lookups have been recorded yet.
    pub fn miss_rate(&self) -> f64 {
        let hit_rate = self.hit_rate();
        1.0 - hit_rate
    }
}
120
121pub struct SshConfigCache {
123 cache: Arc<RwLock<LruCache<PathBuf, CacheEntry>>>,
125 config: CacheConfig,
127 stats: Arc<RwLock<CacheStats>>,
129}
130
131impl SshConfigCache {
132 pub fn new() -> Self {
134 Self::with_config(CacheConfig::default())
135 }
136
137 pub fn with_config(config: CacheConfig) -> Self {
139 let cache_size = std::num::NonZeroUsize::new(config.max_entries)
140 .unwrap_or(std::num::NonZeroUsize::new(100).unwrap());
141
142 let stats = CacheStats {
143 max_entries: config.max_entries,
144 ..Default::default()
145 };
146
147 Self {
148 cache: Arc::new(RwLock::new(LruCache::new(cache_size))),
149 config,
150 stats: Arc::new(RwLock::new(stats)),
151 }
152 }
153
154 pub async fn get_or_load<P: AsRef<Path>>(&self, path: P) -> Result<SshConfig> {
156 if !self.config.enabled {
157 return SshConfig::load_from_file(path).await;
158 }
159
160 let path_ref = path.as_ref();
161 let path = tokio::fs::canonicalize(path_ref)
162 .await
163 .with_context(|| format!("Failed to canonicalize path: {}", path_ref.display()))?;
164
165 let file_metadata = tokio::fs::metadata(&path)
167 .await
168 .with_context(|| format!("Failed to read file metadata: {}", path.display()))?;
169
170 let current_mtime = file_metadata
171 .modified()
172 .with_context(|| format!("Failed to get modification time: {}", path.display()))?;
173
174 if let Some(config) = self.try_get_cached(&path, current_mtime)? {
176 return Ok(config);
177 }
178
179 trace!("Cache miss for SSH config: {}", path.display());
181 let config = SshConfig::load_from_file(&path)
182 .await
183 .with_context(|| format!("Failed to load SSH config from file: {}", path.display()))?;
184
185 self.put(path, config.clone(), current_mtime);
187
188 {
190 let mut stats = self.stats.write().unwrap();
191 stats.misses += 1;
192 }
193
194 Ok(config)
195 }
196
197 fn try_get_cached(&self, path: &Path, current_mtime: SystemTime) -> Result<Option<SshConfig>> {
199 let mut cache = self.cache.write().unwrap();
200
201 if let Some(entry) = cache.get_mut(path) {
202 if entry.is_expired(self.config.ttl) {
204 debug!("SSH config cache entry expired: {}", path.display());
205 cache.pop(path);
206
207 let mut stats = self.stats.write().unwrap();
208 stats.ttl_evictions += 1;
209 return Ok(None);
210 }
211
212 if entry.is_stale(current_mtime) {
214 debug!("SSH config cache entry stale: {}", path.display());
215 cache.pop(path);
216
217 let mut stats = self.stats.write().unwrap();
218 stats.stale_evictions += 1;
219 return Ok(None);
220 }
221
222 let config = entry.access().clone();
224
225 {
227 let mut stats = self.stats.write().unwrap();
228 stats.hits += 1;
229 }
230
231 trace!("SSH config cache hit: {}", path.display());
232 return Ok(Some(config));
233 }
234
235 Ok(None)
236 }
237
238 fn put(&self, path: PathBuf, config: SshConfig, file_mtime: SystemTime) {
240 let mut cache = self.cache.write().unwrap();
241
242 let will_evict = cache.len() >= cache.cap().get();
244
245 let entry = CacheEntry::new(config, file_mtime);
246 cache.put(path.clone(), entry);
247
248 {
250 let mut stats = self.stats.write().unwrap();
251 if will_evict {
252 stats.lru_evictions += 1;
253 }
254 stats.current_entries = cache.len();
255 }
256
257 trace!("SSH config cached: {}", path.display());
258 }
259
260 pub async fn load_default(&self) -> Result<SshConfig> {
262 if !self.config.enabled {
263 return SshConfig::load_default().await;
264 }
265
266 if let Some(home_dir) = dirs::home_dir() {
268 let user_config = home_dir.join(".ssh").join("config");
269 if tokio::fs::try_exists(&user_config).await.unwrap_or(false) {
270 return self.get_or_load(&user_config).await;
271 }
272 }
273
274 let system_config = Path::new("/etc/ssh/ssh_config");
276 if tokio::fs::try_exists(system_config).await.unwrap_or(false) {
277 return self.get_or_load(system_config).await;
278 }
279
280 Ok(SshConfig::new())
282 }
283
284 pub fn clear(&self) {
286 let mut cache = self.cache.write().unwrap();
287 cache.clear();
288
289 let mut stats = self.stats.write().unwrap();
290 stats.current_entries = 0;
291 }
292
293 pub async fn remove<P: AsRef<Path>>(&self, path: P) -> Option<SshConfig> {
295 let path = path.as_ref();
296 if let Ok(canonical_path) = tokio::fs::canonicalize(path).await {
297 let mut cache = self.cache.write().unwrap();
298 let entry = cache.pop(&canonical_path)?;
299
300 let mut stats = self.stats.write().unwrap();
301 stats.current_entries = cache.len();
302
303 Some(entry.config)
304 } else {
305 None
306 }
307 }
308
309 pub fn stats(&self) -> CacheStats {
311 self.stats.read().unwrap().clone()
312 }
313
314 pub fn config(&self) -> &CacheConfig {
316 &self.config
317 }
318
319 pub fn update_config(&mut self, new_config: CacheConfig) {
321 if new_config.max_entries != self.config.max_entries {
322 let cache_size = std::num::NonZeroUsize::new(new_config.max_entries)
324 .unwrap_or(std::num::NonZeroUsize::new(100).unwrap());
325
326 self.cache = Arc::new(RwLock::new(LruCache::new(cache_size)));
327
328 let mut stats = self.stats.write().unwrap();
329 stats.max_entries = new_config.max_entries;
330 stats.current_entries = 0;
331 }
332
333 self.config = new_config;
334 }
335
336 pub async fn maintain(&self) -> usize {
338 if !self.config.enabled {
339 return 0;
340 }
341
342 let mut to_remove = Vec::new();
343 let mut expired_count = 0;
344 let mut stale_count = 0;
345
346 let mut check_tasks = Vec::new();
349
350 {
351 let cache = self.cache.write().unwrap();
353
354 for (path, entry) in cache.iter() {
355 if entry.is_expired(self.config.ttl) {
356 to_remove.push(path.clone());
357 expired_count += 1;
358 } else {
359 let path_clone = path.clone();
360 let entry_mtime = entry.file_mtime;
361 check_tasks.push(tokio::spawn(async move {
362 if let Ok(metadata) = tokio::fs::metadata(&path_clone).await {
363 if let Ok(current_mtime) = metadata.modified() {
364 (path_clone, entry_mtime != current_mtime, true)
365 } else {
366 (path_clone, false, false)
367 }
368 } else {
369 (path_clone, true, false)
371 }
372 }));
373 }
374 }
375 } for task in check_tasks {
379 if let Ok((path, is_stale, _file_exists)) = task.await {
380 if is_stale {
381 to_remove.push(path);
382 stale_count += 1;
383 }
384 }
385 }
386
387 {
389 let mut cache = self.cache.write().unwrap();
390 for path in &to_remove {
391 cache.pop(path);
392 }
393 }
394
395 let removed_count = to_remove.len();
396
397 {
399 let cache = self.cache.read().unwrap();
400 let mut stats = self.stats.write().unwrap();
401 stats.ttl_evictions += expired_count as u64;
402 stats.stale_evictions += stale_count as u64;
403 stats.current_entries = cache.len();
404 }
405
406 if removed_count > 0 {
407 debug!(
408 "SSH config cache maintenance: removed {} entries ({} expired, {} stale)",
409 removed_count, expired_count, stale_count
410 );
411 }
412
413 removed_count
414 }
415
416 pub fn debug_info(&self) -> HashMap<PathBuf, String> {
418 let cache = self.cache.read().unwrap();
419 let mut info = HashMap::new();
420
421 for (path, entry) in cache.iter() {
422 let age = entry.cached_at.elapsed();
423 let is_expired = entry.is_expired(self.config.ttl);
424 let last_accessed = entry.last_accessed.elapsed();
425
426 let status = if is_expired { "EXPIRED" } else { "VALID" };
427
428 info.insert(
429 path.clone(),
430 format!(
431 "Status: {}, Age: {:?}, Accesses: {}, Last accessed: {:?} ago",
432 status, age, entry.access_count, last_accessed
433 ),
434 );
435 }
436
437 info
438 }
439}
440
441impl Default for SshConfigCache {
442 fn default() -> Self {
443 Self::new()
444 }
445}
446
447use once_cell::sync::Lazy;
449
450pub static GLOBAL_CACHE: Lazy<SshConfigCache> = Lazy::new(|| {
452 let config = CacheConfig {
453 max_entries: std::env::var("BSSH_CACHE_SIZE")
454 .ok()
455 .and_then(|s| s.parse().ok())
456 .unwrap_or(100),
457 ttl: Duration::from_secs(
458 std::env::var("BSSH_CACHE_TTL")
459 .ok()
460 .and_then(|s| s.parse().ok())
461 .unwrap_or(300),
462 ),
463 enabled: std::env::var("BSSH_CACHE_ENABLED")
464 .map(|s| s.to_lowercase() != "false" && s != "0")
465 .unwrap_or(true),
466 };
467
468 debug!(
469 "Initializing SSH config cache with {} max entries, {:?} TTL, enabled: {}",
470 config.max_entries, config.ttl, config.enabled
471 );
472
473 SshConfigCache::with_config(config)
474});
475
/// Unit tests for the SSH config cache. File-backed tests write small configs
/// to temporary files and drive the async API with `tokio_test::block_on`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// `CacheConfig::default()` yields 100 entries, a 300 s TTL and caching enabled.
    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();
        assert_eq!(config.max_entries, 100);
        assert_eq!(config.ttl, Duration::from_secs(300));
        assert!(config.enabled);
    }

    /// An entry expires only after its age exceeds the TTL.
    #[test]
    fn test_cache_entry_expiration() {
        let config = SshConfig::new();
        let mtime = SystemTime::now();
        let mut entry = CacheEntry::new(config, mtime);

        // A freshly created entry is well within a 300 s TTL.
        assert!(!entry.is_expired(Duration::from_secs(300)));

        // Backdate the creation time past the TTL.
        // NOTE(review): `Instant::now() - 400s` can panic if the monotonic clock
        // is younger than 400 s (e.g. a freshly booted VM); `checked_sub` would be safer.
        entry.cached_at = Instant::now() - Duration::from_secs(400);
        assert!(entry.is_expired(Duration::from_secs(300)));
    }

    /// An entry is stale exactly when the observed mtime differs from the recorded one.
    #[test]
    fn test_cache_entry_staleness() {
        let config = SshConfig::new();
        let old_mtime = SystemTime::UNIX_EPOCH;
        let new_mtime = SystemTime::now();

        let entry = CacheEntry::new(config, old_mtime);

        assert!(!entry.is_stale(old_mtime));
        assert!(entry.is_stale(new_mtime));
    }

    /// First lookup is a miss that loads the file; the second is a hit served from the cache.
    #[test]
    fn test_cache_basic_operations() {
        let cache = SshConfigCache::new();

        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "Host example").unwrap();
        writeln!(temp_file, " HostName example.com").unwrap();

        let path = temp_file.path().to_path_buf();

        // Cold cache: loaded from disk and counted as a miss.
        let config1 = tokio_test::block_on(cache.get_or_load(&path)).unwrap();
        assert_eq!(config1.hosts.len(), 1);

        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);

        // Warm cache: same file, unchanged mtime, so it is served as a hit.
        let config2 = tokio_test::block_on(cache.get_or_load(&path)).unwrap();
        assert_eq!(config2.hosts.len(), 1);

        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.hit_rate(), 0.5);
    }

    /// Rewriting the file changes its mtime, so the cached entry is evicted as stale and reloaded.
    #[test]
    fn test_cache_file_modification_detection() {
        let cache = SshConfigCache::new();

        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "Host example").unwrap();
        writeln!(temp_file, " HostName example.com").unwrap();
        temp_file.flush().unwrap();

        let path = temp_file.path().to_path_buf();

        let config1 = tokio_test::block_on(cache.get_or_load(&path)).unwrap();
        assert_eq!(config1.hosts.len(), 1);

        // Give the filesystem a chance to record a different mtime before appending.
        // NOTE(review): 10 ms may be shorter than the timestamp granularity of some
        // filesystems/kernels, which would make this test flaky — confirm on CI targets.
        std::thread::sleep(Duration::from_millis(10));
        writeln!(temp_file, "Host another").unwrap();
        writeln!(temp_file, " HostName another.com").unwrap();
        temp_file.flush().unwrap();

        // The new mtime no longer matches the cached entry, so the file is re-parsed.
        let config2 = tokio_test::block_on(cache.get_or_load(&path)).unwrap();
        assert_eq!(config2.hosts.len(), 2);

        let stats = cache.stats();
        assert_eq!(stats.stale_evictions, 1);
    }

    /// With a 50 ms TTL, a lookup after 100 ms evicts the expired entry.
    #[test]
    fn test_cache_ttl_expiration() {
        let config = CacheConfig {
            max_entries: 10,
            ttl: Duration::from_millis(50),
            enabled: true,
        };
        let cache = SshConfigCache::with_config(config);

        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "Host example").unwrap();
        writeln!(temp_file, " HostName example.com").unwrap();

        let path = temp_file.path().to_path_buf();

        let _config1 = tokio_test::block_on(cache.get_or_load(&path)).unwrap();

        // Wait out the TTL.
        std::thread::sleep(Duration::from_millis(100));

        let _config2 = tokio_test::block_on(cache.get_or_load(&path)).unwrap();

        let stats = cache.stats();
        assert_eq!(stats.ttl_evictions, 1);
    }

    /// `remove` drops a single entry and `clear` drops everything; both keep `current_entries` in sync.
    #[test]
    fn test_cache_clear_and_remove() {
        let cache = SshConfigCache::new();

        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "Host example").unwrap();
        writeln!(temp_file, " HostName example.com").unwrap();

        let path = temp_file.path().to_path_buf();

        let _config = tokio_test::block_on(cache.get_or_load(&path)).unwrap();
        assert_eq!(cache.stats().current_entries, 1);

        let removed_config = tokio_test::block_on(cache.remove(&path));
        assert!(removed_config.is_some());
        assert_eq!(cache.stats().current_entries, 0);

        // Re-populate, then clear the whole cache.
        let _config = tokio_test::block_on(cache.get_or_load(&path)).unwrap();
        assert_eq!(cache.stats().current_entries, 1);

        cache.clear();
        assert_eq!(cache.stats().current_entries, 0);
    }

    /// `maintain` removes an entry that has outlived its TTL.
    #[test]
    fn test_cache_maintenance() {
        let config = CacheConfig {
            max_entries: 10,
            ttl: Duration::from_millis(50),
            enabled: true,
        };
        let cache = SshConfigCache::with_config(config);

        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "Host example").unwrap();
        writeln!(temp_file, " HostName example.com").unwrap();

        let path = temp_file.path().to_path_buf();

        let _config = tokio_test::block_on(cache.get_or_load(&path)).unwrap();
        assert_eq!(cache.stats().current_entries, 1);

        // Wait out the TTL so the entry counts as expired during maintenance.
        std::thread::sleep(Duration::from_millis(100));

        let removed = tokio_test::block_on(cache.maintain());
        assert_eq!(removed, 1);
        assert_eq!(cache.stats().current_entries, 0);
    }

    /// With caching disabled, lookups go straight to disk and record no statistics.
    #[test]
    fn test_cache_disabled() {
        let config = CacheConfig {
            max_entries: 10,
            ttl: Duration::from_secs(300),
            enabled: false,
        };
        let cache = SshConfigCache::with_config(config);

        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "Host example").unwrap();
        writeln!(temp_file, " HostName example.com").unwrap();

        let path = temp_file.path().to_path_buf();

        let _config1 = tokio_test::block_on(cache.get_or_load(&path)).unwrap();
        let _config2 = tokio_test::block_on(cache.get_or_load(&path)).unwrap();

        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.current_entries, 0);
    }

    /// A fresh cache reports zeroed counters; `miss_rate` is 1.0 (not 0.0) with no lookups.
    #[test]
    fn test_cache_stats() {
        let cache = SshConfigCache::new();
        let stats = cache.stats();

        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.hit_rate(), 0.0);
        assert_eq!(stats.miss_rate(), 1.0);
        assert_eq!(stats.current_entries, 0);
        assert_eq!(stats.max_entries, 100);
    }
}