bssh/ssh/
config_cache.rs

1// Copyright 2025 Lablup Inc. and Jeongkyu Shin
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::ssh::SshConfig;
16use anyhow::{Context, Result};
17use lru::LruCache;
18use std::collections::HashMap;
19use std::path::{Path, PathBuf};
20use std::sync::{Arc, RwLock};
21use std::time::{Duration, Instant, SystemTime};
22use tracing::{debug, trace};
23
/// Tunable settings for the SSH config cache
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Upper bound on the number of cached entries (default: 100)
    pub max_entries: usize,
    /// How long an entry stays valid after it was cached (default: 300 seconds)
    pub ttl: Duration,
    /// Master switch for caching (default: true)
    pub enabled: bool,
}

impl Default for CacheConfig {
    /// Stock settings: 100 entries, a five-minute TTL, caching switched on.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(5 * 60);

        CacheConfig {
            enabled: true,
            ttl: five_minutes,
            max_entries: 100,
        }
    }
}
44
45/// Metadata about a cached SSH config entry
46#[derive(Debug, Clone)]
47struct CacheEntry {
48    /// The cached SSH configuration
49    config: SshConfig,
50    /// When this entry was cached
51    cached_at: Instant,
52    /// File modification time when this entry was cached
53    file_mtime: SystemTime,
54    /// Number of times this entry has been accessed
55    access_count: u64,
56    /// Last access time
57    last_accessed: Instant,
58}
59
60impl CacheEntry {
61    fn new(config: SshConfig, file_mtime: SystemTime) -> Self {
62        let now = Instant::now();
63        Self {
64            config,
65            cached_at: now,
66            file_mtime,
67            access_count: 0,
68            last_accessed: now,
69        }
70    }
71
72    fn is_expired(&self, ttl: Duration) -> bool {
73        self.cached_at.elapsed() > ttl
74    }
75
76    fn is_stale(&self, current_mtime: SystemTime) -> bool {
77        self.file_mtime != current_mtime
78    }
79
80    fn access(&mut self) -> &SshConfig {
81        self.access_count += 1;
82        self.last_accessed = Instant::now();
83        &self.config
84    }
85}
86
/// Cache statistics for monitoring and debugging
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Total number of cache hits
    pub hits: u64,
    /// Total number of cache misses
    pub misses: u64,
    /// Number of entries evicted due to TTL expiration
    pub ttl_evictions: u64,
    /// Number of entries evicted due to file modification
    pub stale_evictions: u64,
    /// Number of entries evicted due to LRU policy
    pub lru_evictions: u64,
    /// Current number of entries in cache
    pub current_entries: usize,
    /// Maximum number of entries allowed
    pub max_entries: usize,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in the range `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        // Sum in u128 so that extreme counter values cannot overflow: a plain
        // `u64` addition would panic in debug builds and wrap in release builds.
        let total = u128::from(self.hits) + u128::from(self.misses);
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }

    /// Complement of [`CacheStats::hit_rate`].
    ///
    /// Because `hit_rate` is `0.0` before any lookup, this reports `1.0` for a
    /// cache that has not been queried yet.
    pub fn miss_rate(&self) -> f64 {
        1.0 - self.hit_rate()
    }
}
120
/// Thread-safe LRU cache for SSH configurations
///
/// Entries are keyed by file path. The entry map and the statistics are each
/// guarded by their own `RwLock`.
pub struct SshConfigCache {
    /// LRU cache implementation: maps a config file path to its cached entry
    cache: Arc<RwLock<LruCache<PathBuf, CacheEntry>>>,
    /// Cache configuration
    config: CacheConfig,
    /// Cache statistics
    stats: Arc<RwLock<CacheStats>>,
}
130
131impl SshConfigCache {
132    /// Create a new SSH config cache with default configuration
133    pub fn new() -> Self {
134        Self::with_config(CacheConfig::default())
135    }
136
137    /// Create a new SSH config cache with custom configuration
138    pub fn with_config(config: CacheConfig) -> Self {
139        let cache_size = std::num::NonZeroUsize::new(config.max_entries)
140            .unwrap_or(std::num::NonZeroUsize::new(100).unwrap());
141
142        let stats = CacheStats {
143            max_entries: config.max_entries,
144            ..Default::default()
145        };
146
147        Self {
148            cache: Arc::new(RwLock::new(LruCache::new(cache_size))),
149            config,
150            stats: Arc::new(RwLock::new(stats)),
151        }
152    }
153
154    /// Get an SSH config from cache or load it from file
155    pub async fn get_or_load<P: AsRef<Path>>(&self, path: P) -> Result<SshConfig> {
156        if !self.config.enabled {
157            return SshConfig::load_from_file(path).await;
158        }
159
160        let path_ref = path.as_ref();
161        let path = tokio::fs::canonicalize(path_ref)
162            .await
163            .with_context(|| format!("Failed to canonicalize path: {}", path_ref.display()))?;
164
165        // Check if file exists and get its modification time
166        let file_metadata = tokio::fs::metadata(&path)
167            .await
168            .with_context(|| format!("Failed to read file metadata: {}", path.display()))?;
169
170        let current_mtime = file_metadata
171            .modified()
172            .with_context(|| format!("Failed to get modification time: {}", path.display()))?;
173
174        // Try to get from cache first
175        if let Some(config) = self.try_get_cached(&path, current_mtime)? {
176            return Ok(config);
177        }
178
179        // Cache miss - load from file
180        trace!("Cache miss for SSH config: {}", path.display());
181        let config = SshConfig::load_from_file(&path)
182            .await
183            .with_context(|| format!("Failed to load SSH config from file: {}", path.display()))?;
184
185        // Store in cache
186        self.put(path, config.clone(), current_mtime);
187
188        // Update statistics
189        {
190            let mut stats = self.stats.write().unwrap();
191            stats.misses += 1;
192        }
193
194        Ok(config)
195    }
196
197    /// Try to get a cached entry, checking for expiration and staleness
198    fn try_get_cached(&self, path: &Path, current_mtime: SystemTime) -> Result<Option<SshConfig>> {
199        let mut cache = self.cache.write().unwrap();
200
201        if let Some(entry) = cache.get_mut(path) {
202            // Check if entry is expired
203            if entry.is_expired(self.config.ttl) {
204                debug!("SSH config cache entry expired: {}", path.display());
205                cache.pop(path);
206
207                let mut stats = self.stats.write().unwrap();
208                stats.ttl_evictions += 1;
209                return Ok(None);
210            }
211
212            // Check if entry is stale (file was modified)
213            if entry.is_stale(current_mtime) {
214                debug!("SSH config cache entry stale: {}", path.display());
215                cache.pop(path);
216
217                let mut stats = self.stats.write().unwrap();
218                stats.stale_evictions += 1;
219                return Ok(None);
220            }
221
222            // Entry is valid - access it and return
223            let config = entry.access().clone();
224
225            // Update statistics
226            {
227                let mut stats = self.stats.write().unwrap();
228                stats.hits += 1;
229            }
230
231            trace!("SSH config cache hit: {}", path.display());
232            return Ok(Some(config));
233        }
234
235        Ok(None)
236    }
237
238    /// Put an entry in the cache
239    fn put(&self, path: PathBuf, config: SshConfig, file_mtime: SystemTime) {
240        let mut cache = self.cache.write().unwrap();
241
242        // Check if we're evicting an entry due to LRU policy
243        let will_evict = cache.len() >= cache.cap().get();
244
245        let entry = CacheEntry::new(config, file_mtime);
246        cache.put(path.clone(), entry);
247
248        // Update statistics
249        {
250            let mut stats = self.stats.write().unwrap();
251            if will_evict {
252                stats.lru_evictions += 1;
253            }
254            stats.current_entries = cache.len();
255        }
256
257        trace!("SSH config cached: {}", path.display());
258    }
259
260    /// Load SSH config from default locations with caching
261    pub async fn load_default(&self) -> Result<SshConfig> {
262        if !self.config.enabled {
263            return SshConfig::load_default().await;
264        }
265
266        // Try user-specific SSH config first
267        if let Some(home_dir) = dirs::home_dir() {
268            let user_config = home_dir.join(".ssh").join("config");
269            if tokio::fs::try_exists(&user_config).await.unwrap_or(false) {
270                return self.get_or_load(&user_config).await;
271            }
272        }
273
274        // Try system-wide SSH config
275        let system_config = Path::new("/etc/ssh/ssh_config");
276        if tokio::fs::try_exists(system_config).await.unwrap_or(false) {
277            return self.get_or_load(system_config).await;
278        }
279
280        // Return empty config if no files found
281        Ok(SshConfig::new())
282    }
283
284    /// Clear all entries from the cache
285    pub fn clear(&self) {
286        let mut cache = self.cache.write().unwrap();
287        cache.clear();
288
289        let mut stats = self.stats.write().unwrap();
290        stats.current_entries = 0;
291    }
292
293    /// Remove a specific entry from the cache
294    pub async fn remove<P: AsRef<Path>>(&self, path: P) -> Option<SshConfig> {
295        let path = path.as_ref();
296        if let Ok(canonical_path) = tokio::fs::canonicalize(path).await {
297            let mut cache = self.cache.write().unwrap();
298            let entry = cache.pop(&canonical_path)?;
299
300            let mut stats = self.stats.write().unwrap();
301            stats.current_entries = cache.len();
302
303            Some(entry.config)
304        } else {
305            None
306        }
307    }
308
309    /// Get current cache statistics
310    pub fn stats(&self) -> CacheStats {
311        self.stats.read().unwrap().clone()
312    }
313
    /// Get cache configuration
    ///
    /// Returns a shared reference to the settings this cache currently holds.
    pub fn config(&self) -> &CacheConfig {
        &self.config
    }
318
319    /// Update cache configuration (will clear cache if max_entries changed)
320    pub fn update_config(&mut self, new_config: CacheConfig) {
321        if new_config.max_entries != self.config.max_entries {
322            // Need to recreate cache with new size
323            let cache_size = std::num::NonZeroUsize::new(new_config.max_entries)
324                .unwrap_or(std::num::NonZeroUsize::new(100).unwrap());
325
326            self.cache = Arc::new(RwLock::new(LruCache::new(cache_size)));
327
328            let mut stats = self.stats.write().unwrap();
329            stats.max_entries = new_config.max_entries;
330            stats.current_entries = 0;
331        }
332
333        self.config = new_config;
334    }
335
336    /// Perform cache maintenance (remove expired and stale entries)
337    pub async fn maintain(&self) -> usize {
338        if !self.config.enabled {
339            return 0;
340        }
341
342        let mut to_remove = Vec::new();
343        let mut expired_count = 0;
344        let mut stale_count = 0;
345
346        // Collect keys to check and expired entries (can't remove while iterating)
347        // We'll use tokio::spawn to check file metadata concurrently
348        let mut check_tasks = Vec::new();
349
350        {
351            // Scope the lock to release it before awaiting
352            let cache = self.cache.write().unwrap();
353
354            for (path, entry) in cache.iter() {
355                if entry.is_expired(self.config.ttl) {
356                    to_remove.push(path.clone());
357                    expired_count += 1;
358                } else {
359                    let path_clone = path.clone();
360                    let entry_mtime = entry.file_mtime;
361                    check_tasks.push(tokio::spawn(async move {
362                        if let Ok(metadata) = tokio::fs::metadata(&path_clone).await {
363                            if let Ok(current_mtime) = metadata.modified() {
364                                (path_clone, entry_mtime != current_mtime, true)
365                            } else {
366                                (path_clone, false, false)
367                            }
368                        } else {
369                            // File doesn't exist anymore
370                            (path_clone, true, false)
371                        }
372                    }));
373                }
374            }
375        } // Lock is dropped here
376
377        // Wait for all file checks to complete
378        for task in check_tasks {
379            if let Ok((path, is_stale, _file_exists)) = task.await {
380                if is_stale {
381                    to_remove.push(path);
382                    stale_count += 1;
383                }
384            }
385        }
386
387        // Remove expired and stale entries
388        {
389            let mut cache = self.cache.write().unwrap();
390            for path in &to_remove {
391                cache.pop(path);
392            }
393        }
394
395        let removed_count = to_remove.len();
396
397        // Update statistics
398        {
399            let cache = self.cache.read().unwrap();
400            let mut stats = self.stats.write().unwrap();
401            stats.ttl_evictions += expired_count as u64;
402            stats.stale_evictions += stale_count as u64;
403            stats.current_entries = cache.len();
404        }
405
406        if removed_count > 0 {
407            debug!(
408                "SSH config cache maintenance: removed {} entries ({} expired, {} stale)",
409                removed_count, expired_count, stale_count
410            );
411        }
412
413        removed_count
414    }
415
416    /// Get detailed information about cache entries (for debugging)
417    pub fn debug_info(&self) -> HashMap<PathBuf, String> {
418        let cache = self.cache.read().unwrap();
419        let mut info = HashMap::new();
420
421        for (path, entry) in cache.iter() {
422            let age = entry.cached_at.elapsed();
423            let is_expired = entry.is_expired(self.config.ttl);
424            let last_accessed = entry.last_accessed.elapsed();
425
426            let status = if is_expired { "EXPIRED" } else { "VALID" };
427
428            info.insert(
429                path.clone(),
430                format!(
431                    "Status: {}, Age: {:?}, Accesses: {}, Last accessed: {:?} ago",
432                    status, age, entry.access_count, last_accessed
433                ),
434            );
435        }
436
437        info
438    }
439}
440
441impl Default for SshConfigCache {
442    fn default() -> Self {
443        Self::new()
444    }
445}
446
447// Global cache instance using once_cell for thread-safe lazy initialization
448use once_cell::sync::Lazy;
449
450/// Global SSH config cache instance
451pub static GLOBAL_CACHE: Lazy<SshConfigCache> = Lazy::new(|| {
452    let config = CacheConfig {
453        max_entries: std::env::var("BSSH_CACHE_SIZE")
454            .ok()
455            .and_then(|s| s.parse().ok())
456            .unwrap_or(100),
457        ttl: Duration::from_secs(
458            std::env::var("BSSH_CACHE_TTL")
459                .ok()
460                .and_then(|s| s.parse().ok())
461                .unwrap_or(300),
462        ),
463        enabled: std::env::var("BSSH_CACHE_ENABLED")
464            .map(|s| s.to_lowercase() != "false" && s != "0")
465            .unwrap_or(true),
466    };
467
468    debug!(
469        "Initializing SSH config cache with {} max entries, {:?} TTL, enabled: {}",
470        config.max_entries, config.ttl, config.enabled
471    );
472
473    SshConfigCache::with_config(config)
474});
475
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Create a temporary SSH config file containing a single `Host example` stanza.
    fn sample_config_file() -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, "Host example").unwrap();
        writeln!(file, "    HostName example.com").unwrap();
        file.flush().unwrap();
        file
    }

    /// Build an enabled cache whose entries expire after 50 milliseconds.
    fn short_ttl_cache() -> SshConfigCache {
        SshConfigCache::with_config(CacheConfig {
            max_entries: 10,
            ttl: Duration::from_millis(50),
            enabled: true,
        })
    }

    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();
        assert_eq!(config.max_entries, 100);
        assert_eq!(config.ttl, Duration::from_secs(300));
        assert!(config.enabled);
    }

    #[test]
    fn test_cache_entry_expiration() {
        let config = SshConfig::new();
        let mtime = SystemTime::now();
        let mut entry = CacheEntry::new(config, mtime);

        // Fresh entry should not be expired
        assert!(!entry.is_expired(Duration::from_secs(300)));

        // Simulate time passing by backdating the entry. `Instant - Duration`
        // panics when the result is not representable (e.g. a monotonic clock
        // younger than the offset), so use the checked form and fall back to
        // really waiting with a tiny TTL.
        match Instant::now().checked_sub(Duration::from_secs(400)) {
            Some(past) => {
                entry.cached_at = past;
                assert!(entry.is_expired(Duration::from_secs(300)));
            }
            None => {
                std::thread::sleep(Duration::from_millis(5));
                assert!(entry.is_expired(Duration::from_millis(1)));
            }
        }
    }

    #[test]
    fn test_cache_entry_staleness() {
        let config = SshConfig::new();
        let old_mtime = SystemTime::UNIX_EPOCH;
        let new_mtime = SystemTime::now();

        let entry = CacheEntry::new(config, old_mtime);

        assert!(!entry.is_stale(old_mtime));
        assert!(entry.is_stale(new_mtime));
    }

    #[test]
    fn test_cache_basic_operations() {
        let cache = SshConfigCache::new();

        let temp_file = sample_config_file();
        let path = temp_file.path().to_path_buf();

        // First load should be a cache miss
        let config1 = tokio_test::block_on(cache.get_or_load(&path)).unwrap();
        assert_eq!(config1.hosts.len(), 1);

        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);

        // Second load should be a cache hit
        let config2 = tokio_test::block_on(cache.get_or_load(&path)).unwrap();
        assert_eq!(config2.hosts.len(), 1);

        let stats = cache.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.hit_rate(), 0.5);
    }

    #[test]
    fn test_cache_file_modification_detection() {
        let cache = SshConfigCache::new();

        let mut temp_file = sample_config_file();
        let path = temp_file.path().to_path_buf();

        // Load initial config
        let config1 = tokio_test::block_on(cache.get_or_load(&path)).unwrap();
        assert_eq!(config1.hosts.len(), 1);

        // Modify the file
        std::thread::sleep(Duration::from_millis(10)); // Ensure different mtime
        writeln!(temp_file, "Host another").unwrap();
        writeln!(temp_file, "    HostName another.com").unwrap();
        temp_file.flush().unwrap();

        // Should detect file modification and reload
        let config2 = tokio_test::block_on(cache.get_or_load(&path)).unwrap();
        assert_eq!(config2.hosts.len(), 2);

        let stats = cache.stats();
        assert_eq!(stats.stale_evictions, 1);
    }

    #[test]
    fn test_cache_ttl_expiration() {
        let cache = short_ttl_cache();

        let temp_file = sample_config_file();
        let path = temp_file.path().to_path_buf();

        // Load initial config
        let _config1 = tokio_test::block_on(cache.get_or_load(&path)).unwrap();

        // Wait for TTL to expire
        std::thread::sleep(Duration::from_millis(100));

        // Should reload due to TTL expiration
        let _config2 = tokio_test::block_on(cache.get_or_load(&path)).unwrap();

        let stats = cache.stats();
        assert_eq!(stats.ttl_evictions, 1);
    }

    #[test]
    fn test_cache_clear_and_remove() {
        let cache = SshConfigCache::new();

        let temp_file = sample_config_file();
        let path = temp_file.path().to_path_buf();

        // Load config
        let _config = tokio_test::block_on(cache.get_or_load(&path)).unwrap();
        assert_eq!(cache.stats().current_entries, 1);

        // Remove specific entry
        let removed_config = tokio_test::block_on(cache.remove(&path));
        assert!(removed_config.is_some());
        assert_eq!(cache.stats().current_entries, 0);

        // Load again and clear all
        let _config = tokio_test::block_on(cache.get_or_load(&path)).unwrap();
        assert_eq!(cache.stats().current_entries, 1);

        cache.clear();
        assert_eq!(cache.stats().current_entries, 0);
    }

    #[test]
    fn test_cache_maintenance() {
        let cache = short_ttl_cache();

        let temp_file = sample_config_file();
        let path = temp_file.path().to_path_buf();

        // Load config
        let _config = tokio_test::block_on(cache.get_or_load(&path)).unwrap();
        assert_eq!(cache.stats().current_entries, 1);

        // Wait for expiration
        std::thread::sleep(Duration::from_millis(100));

        // Run maintenance
        let removed = tokio_test::block_on(cache.maintain());
        assert_eq!(removed, 1);
        assert_eq!(cache.stats().current_entries, 0);
    }

    #[test]
    fn test_cache_disabled() {
        let config = CacheConfig {
            max_entries: 10,
            ttl: Duration::from_secs(300),
            enabled: false,
        };
        let cache = SshConfigCache::with_config(config);

        let temp_file = sample_config_file();
        let path = temp_file.path().to_path_buf();

        // Should not use cache when disabled
        let _config1 = tokio_test::block_on(cache.get_or_load(&path)).unwrap();
        let _config2 = tokio_test::block_on(cache.get_or_load(&path)).unwrap();

        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.current_entries, 0);
    }

    #[test]
    fn test_cache_stats() {
        let cache = SshConfigCache::new();
        let stats = cache.stats();

        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.hit_rate(), 0.0);
        assert_eq!(stats.miss_rate(), 1.0);
        assert_eq!(stats.current_entries, 0);
        assert_eq!(stats.max_entries, 100);
    }
}