prometheus_mcp/mcp/
repository.rs

1use std::sync::{Arc, RwLock};
2use std::time::{Duration, Instant};
3
4use async_trait::async_trait;
5
6use crate::mcp::prometheus_client::{
7    MetricMetadata, PrometheusClient, PrometheusError, PrometheusQueryResult,
8};
9use crate::mcp::prometheus_config::PrometheusConfig;
10
11use once_cell::sync::Lazy;
12
13#[async_trait]
14pub trait PrometheusRepository: Send + Sync {
15    async fn query(
16        &self,
17        query: &str,
18        time: Option<&str>,
19    ) -> Result<PrometheusQueryResult, PrometheusError>;
20    async fn query_range(
21        &self,
22        query: &str,
23        start: &str,
24        end: &str,
25        step: &str,
26    ) -> Result<PrometheusQueryResult, PrometheusError>;
27    async fn list_metrics(&self) -> Result<Vec<String>, PrometheusError>;
28    async fn get_metadata(&self, metric: &str) -> Result<Vec<MetricMetadata>, PrometheusError>;
29    async fn get_series(
30        &self,
31        match_strings: Vec<&str>,
32    ) -> Result<Vec<std::collections::HashMap<String, String>>, PrometheusError>;
33    async fn get_label_values(&self, label_name: &str) -> Result<Vec<String>, PrometheusError>;
34}
35
36pub struct HttpPrometheusRepository {
37    client: PrometheusClient,
38    // Simple caches
39    metrics_cache: RwLock<Option<(Instant, Vec<String>)>>, // cache for list_metrics
40    labels_cache: RwLock<std::collections::HashMap<String, (Instant, Vec<String>)>>, // per-label cache
41    cache_ttl: Duration,
42}
43
44impl HttpPrometheusRepository {
45    pub fn new(config: PrometheusConfig) -> Result<Self, PrometheusError> {
46        let client = PrometheusClient::new(config.clone())?;
47        let ttl = config
48            .cache_ttl_secs
49            .map(Duration::from_secs)
50            .unwrap_or_else(|| Duration::from_secs(0));
51        Ok(Self {
52            client,
53            metrics_cache: RwLock::new(None),
54            labels_cache: RwLock::new(std::collections::HashMap::new()),
55            cache_ttl: ttl,
56        })
57    }
58
59    pub fn from_env() -> Result<Self, PrometheusError> {
60        Self::new(PrometheusConfig::from_env())
61    }
62
63    fn is_expired(ts: Instant, ttl: Duration) -> bool {
64        ttl > Duration::from_secs(0) && ts.elapsed() > ttl
65    }
66}
67
68#[async_trait]
69impl PrometheusRepository for HttpPrometheusRepository {
70    async fn query(
71        &self,
72        query: &str,
73        time: Option<&str>,
74    ) -> Result<PrometheusQueryResult, PrometheusError> {
75        self.client.query(query, time).await
76    }
77
78    async fn query_range(
79        &self,
80        query: &str,
81        start: &str,
82        end: &str,
83        step: &str,
84    ) -> Result<PrometheusQueryResult, PrometheusError> {
85        self.client.query_range(query, start, end, step).await
86    }
87
88    async fn list_metrics(&self) -> Result<Vec<String>, PrometheusError> {
89        // Try cache
90        if self.cache_ttl > Duration::from_secs(0) {
91            if let Some((ts, cached)) = self.metrics_cache.read().unwrap().as_ref() {
92                if !Self::is_expired(*ts, self.cache_ttl) {
93                    return Ok(cached.clone());
94                }
95            }
96        }
97        let fresh = self.client.list_metrics().await?;
98        if self.cache_ttl > Duration::from_secs(0) {
99            *self.metrics_cache.write().unwrap() = Some((Instant::now(), fresh.clone()));
100        }
101        Ok(fresh)
102    }
103
104    async fn get_metadata(&self, metric: &str) -> Result<Vec<MetricMetadata>, PrometheusError> {
105        self.client.get_metadata(metric).await
106    }
107
108    async fn get_series(
109        &self,
110        match_strings: Vec<&str>,
111    ) -> Result<Vec<std::collections::HashMap<String, String>>, PrometheusError> {
112        self.client.get_series(match_strings).await
113    }
114
115    async fn get_label_values(&self, label_name: &str) -> Result<Vec<String>, PrometheusError> {
116        if self.cache_ttl > Duration::from_secs(0) {
117            if let Some((ts, cached)) = self.labels_cache.read().unwrap().get(label_name) {
118                if !Self::is_expired(*ts, self.cache_ttl) {
119                    return Ok(cached.clone());
120                }
121            }
122        }
123        let fresh = self.client.get_label_values(label_name).await?;
124        if self.cache_ttl > Duration::from_secs(0) {
125            self.labels_cache
126                .write()
127                .unwrap()
128                .insert(label_name.to_string(), (Instant::now(), fresh.clone()));
129        }
130        Ok(fresh)
131    }
132}
133
134static REPO: Lazy<RwLock<Option<Arc<dyn PrometheusRepository>>>> = Lazy::new(|| RwLock::new(None));
135
136pub fn get_repository() -> Arc<dyn PrometheusRepository> {
137    if let Some(repo) = REPO.read().unwrap().as_ref() {
138        return Arc::clone(repo);
139    }
140    // Build default repo
141    match HttpPrometheusRepository::from_env() {
142        Ok(http) => {
143            let arc: Arc<dyn PrometheusRepository> = Arc::new(http);
144            *REPO.write().unwrap() = Some(Arc::clone(&arc));
145            arc
146        }
147        Err(err) => {
148            // Fallback repo that returns the error on all calls
149            struct ErrRepo {
150                err: PrometheusError,
151            }
152            #[async_trait]
153            impl PrometheusRepository for ErrRepo {
154                async fn query(
155                    &self,
156                    _query: &str,
157                    _time: Option<&str>,
158                ) -> Result<PrometheusQueryResult, PrometheusError> {
159                    Err(PrometheusError::ApiError(format!(
160                        "Repository init error: {:?}",
161                        self.err
162                    )))
163                }
164                async fn query_range(
165                    &self,
166                    _query: &str,
167                    _start: &str,
168                    _end: &str,
169                    _step: &str,
170                ) -> Result<PrometheusQueryResult, PrometheusError> {
171                    Err(PrometheusError::ApiError(format!(
172                        "Repository init error: {:?}",
173                        self.err
174                    )))
175                }
176                async fn list_metrics(&self) -> Result<Vec<String>, PrometheusError> {
177                    Err(PrometheusError::ApiError(format!(
178                        "Repository init error: {:?}",
179                        self.err
180                    )))
181                }
182                async fn get_metadata(
183                    &self,
184                    _metric: &str,
185                ) -> Result<Vec<MetricMetadata>, PrometheusError> {
186                    Err(PrometheusError::ApiError(format!(
187                        "Repository init error: {:?}",
188                        self.err
189                    )))
190                }
191                async fn get_series(
192                    &self,
193                    _match_strings: Vec<&str>,
194                ) -> Result<Vec<std::collections::HashMap<String, String>>, PrometheusError>
195                {
196                    Err(PrometheusError::ApiError(format!(
197                        "Repository init error: {:?}",
198                        self.err
199                    )))
200                }
201                async fn get_label_values(
202                    &self,
203                    _label_name: &str,
204                ) -> Result<Vec<String>, PrometheusError> {
205                    Err(PrometheusError::ApiError(format!(
206                        "Repository init error: {:?}",
207                        self.err
208                    )))
209                }
210            }
211            let arc: Arc<dyn PrometheusRepository> = Arc::new(ErrRepo { err });
212            *REPO.write().unwrap() = Some(Arc::clone(&arc));
213            arc
214        }
215    }
216}
217
218/// Override the repository instance (DI for tests or custom setups)
219#[allow(dead_code)]
220pub fn set_repository(repo: Arc<dyn PrometheusRepository>) {
221    *REPO.write().unwrap() = Some(repo);
222}
223
/// Test-only alias kept for older call sites; new code should call `set_repository`.
///
/// # Panics
/// Panics if the `REPO` lock is poisoned.
#[cfg(test)]
pub fn set_repository_for_tests(repo: Arc<dyn PrometheusRepository>) {
    set_repository(repo);
}