scouter_client/http/
client.rs

1#![allow(clippy::useless_conversion)]
2use crate::error::ClientError;
3use pyo3::{prelude::*, IntoPyObjectExt};
4use scouter_settings::http::HttpConfig;
5use scouter_types::contracts::{
6    DriftAlertRequest, DriftRequest, GetProfileRequest, ProfileRequest, ProfileStatusRequest,
7};
8use scouter_types::http::{RequestType, Routes};
9use scouter_types::sql::TraceFilters;
10use scouter_types::{
11    RegisteredProfileResponse, TagsRequest, TagsResponse, TraceBaggageResponse,
12    TraceMetricsRequest, TraceMetricsResponse, TracePaginationResponse, TraceRequest,
13    TraceSpansResponse,
14};
15
16use crate::http::HttpClient;
17use scouter_types::{
18    alert::Alert, psi::BinnedPsiFeatureMetrics, spc::SpcDriftFeatures, BinnedMetrics, DriftProfile,
19    DriftType, PyHelperFuncs,
20};
21use std::path::PathBuf;
22use tracing::{debug, error};
23
/// Chunk size used for downloads: 5 × 1024 × 1024 (5 MiB, assuming the unit is bytes).
pub const DOWNLOAD_CHUNK_SIZE: usize = 5 * 1024 * 1024;
25
26#[derive(Debug, Clone)]
27pub struct ScouterClient {
28    client: HttpClient,
29}
30
31impl ScouterClient {
32    pub fn new(config: Option<HttpConfig>) -> Result<Self, ClientError> {
33        let client = HttpClient::new(config.unwrap_or_default())?;
34
35        Ok(ScouterClient { client })
36    }
37
38    /// Insert a profile into the scouter server
39    pub fn insert_profile(
40        &self,
41        request: &ProfileRequest,
42    ) -> Result<RegisteredProfileResponse, ClientError> {
43        let response = self.client.request(
44            Routes::Profile,
45            RequestType::Post,
46            Some(serde_json::to_value(request).unwrap()),
47            None,
48            None,
49        )?;
50
51        if response.status().is_success() {
52            let body = response.bytes()?;
53            let profile_response: RegisteredProfileResponse = serde_json::from_slice(&body)?;
54
55            debug!("Profile inserted successfully: {:?}", profile_response);
56            Ok(profile_response)
57        } else {
58            Err(ClientError::InsertProfileError)
59        }
60    }
61
62    pub fn update_profile_status(
63        &self,
64        request: &ProfileStatusRequest,
65    ) -> Result<bool, ClientError> {
66        let response = self.client.request(
67            Routes::ProfileStatus,
68            RequestType::Put,
69            Some(serde_json::to_value(request).unwrap()),
70            None,
71            None,
72        )?;
73
74        if response.status().is_success() {
75            Ok(true)
76        } else {
77            Err(ClientError::UpdateProfileError)
78        }
79    }
80
81    pub fn get_alerts(&self, request: &DriftAlertRequest) -> Result<Vec<Alert>, ClientError> {
82        debug!("Getting alerts for: {:?}", request);
83
84        let query_string = serde_qs::to_string(request)?;
85
86        let response = self.client.request(
87            Routes::Alerts,
88            RequestType::Get,
89            None,
90            Some(query_string),
91            None,
92        )?;
93
94        // Check response status
95        if !response.status().is_success() {
96            return Err(ClientError::GetDriftAlertError);
97        }
98
99        // Parse response body
100        let body: serde_json::Value = response.json()?;
101
102        // Extract alerts from response
103        let alerts = body
104            .get("alerts")
105            .map(|alerts| {
106                serde_json::from_value::<Vec<Alert>>(alerts.clone()).inspect_err(|e| {
107                    error!(
108                        "Failed to parse drift alerts {:?}. Error: {:?}",
109                        &request, e
110                    );
111                })
112            })
113            .unwrap_or_else(|| {
114                error!("No alerts found in response");
115                Ok(Vec::new())
116            })?;
117
118        Ok(alerts)
119    }
120
121    pub fn get_drift_profile(
122        &self,
123        request: GetProfileRequest,
124    ) -> Result<DriftProfile, ClientError> {
125        let query_string = serde_qs::to_string(&request)?;
126
127        let response = self.client.request(
128            Routes::Profile,
129            RequestType::Get,
130            None,
131            Some(query_string),
132            None,
133        )?;
134
135        // Early return for error status codes
136        if !response.status().is_success() {
137            error!("Failed to get profile. Status: {:?}", response.status());
138            return Err(ClientError::GetDriftProfileError);
139        }
140
141        // Get response body
142        let body = response.bytes()?;
143
144        // Parse JSON response
145        let profile: DriftProfile = serde_json::from_slice(&body)?;
146
147        Ok(profile)
148    }
149
150    /// Check if the scouter server is healthy
151    pub fn check_service_health(&self) -> Result<bool, ClientError> {
152        let response = self
153            .client
154            .request(Routes::Healthcheck, RequestType::Get, None, None, None)
155            .inspect_err(|e| {
156                error!("Failed to check scouter health {}", e);
157            })?;
158
159        if response.status() == 200 {
160            Ok(true)
161        } else {
162            Ok(false)
163        }
164    }
165
166    fn get_paginated_traces(
167        &self,
168        request: &TraceFilters,
169    ) -> Result<TracePaginationResponse, ClientError> {
170        let response = self.client.request(
171            Routes::PaginatedTraces,
172            RequestType::Post,
173            Some(serde_json::to_value(request).unwrap()),
174            None,
175            None,
176        )?;
177
178        if !response.status().is_success() {
179            let status_code = response.status();
180            let err_msg = response.text().unwrap_or_default();
181            error!(
182                "Failed to get paginated traces. Status: {:?}, Error: {}",
183                status_code, err_msg
184            );
185            return Err(ClientError::GetPaginatedTracesError);
186        }
187
188        // Get response body
189        let body = response.bytes()?;
190
191        // Parse JSON response
192        let response: TracePaginationResponse = serde_json::from_slice(&body)?;
193        Ok(response)
194    }
195
196    fn refresh_trace_summary(&self) -> Result<bool, ClientError> {
197        let response = self.client.request(
198            Routes::RefreshTraceSummary,
199            RequestType::Get,
200            None,
201            None,
202            None,
203        )?;
204        if !response.status().is_success() {
205            error!(
206                "Failed to refresh trace summary. Status: {:?}",
207                response.status()
208            );
209            return Err(ClientError::RefreshTraceSummaryError);
210        }
211
212        Ok(true)
213    }
214
215    fn get_trace_spans(&self, trace_id: &str) -> Result<TraceSpansResponse, ClientError> {
216        let trace_request = TraceRequest {
217            trace_id: trace_id.to_string(),
218        };
219
220        let query_string = serde_qs::to_string(&trace_request)?;
221
222        let response = self.client.request(
223            Routes::TraceSpans,
224            RequestType::Get,
225            None,
226            Some(query_string),
227            None,
228        )?;
229        if !response.status().is_success() {
230            error!("Failed to get trace spans. Status: {:?}", response.status());
231            return Err(ClientError::GetTraceSpansError);
232        }
233
234        // Get response body
235        let body = response.bytes()?;
236        // Parse JSON response
237        let response: TraceSpansResponse = serde_json::from_slice(&body)?;
238        Ok(response)
239    }
240
241    fn get_trace_metrics(
242        &self,
243        request: TraceMetricsRequest,
244    ) -> Result<TraceMetricsResponse, ClientError> {
245        let query_string = serde_qs::to_string(&request)?;
246        let response = self.client.request(
247            Routes::TraceMetrics,
248            RequestType::Get,
249            None,
250            Some(query_string),
251            None,
252        )?;
253        if !response.status().is_success() {
254            error!(
255                "Failed to get trace metrics. Status: {:?}",
256                response.status()
257            );
258            return Err(ClientError::GetTraceMetricsError);
259        }
260
261        // Get response body
262        let body = response.bytes()?;
263        // Parse JSON response
264        let response: TraceMetricsResponse = serde_json::from_slice(&body)?;
265        Ok(response)
266    }
267
268    fn get_trace_baggage(&self, trace_id: &str) -> Result<TraceBaggageResponse, ClientError> {
269        let trace_request = TraceRequest {
270            trace_id: trace_id.to_string(),
271        };
272        let query_string = serde_qs::to_string(&trace_request)?;
273        let response = self.client.request(
274            Routes::TraceBaggage,
275            RequestType::Get,
276            None,
277            Some(query_string),
278            None,
279        )?;
280        if !response.status().is_success() {
281            error!(
282                "Failed to get trace baggage. Status: {:?}",
283                response.status()
284            );
285            return Err(ClientError::GetTraceBaggageError);
286        }
287
288        // Get response body
289        let body = response.bytes()?;
290        // Parse JSON response
291        let response: TraceBaggageResponse = serde_json::from_slice(&body)?;
292        Ok(response)
293    }
294
295    fn get_tags(&self, tag_request: TagsRequest) -> Result<TagsResponse, ClientError> {
296        let query_string = serde_qs::to_string(&tag_request)?;
297
298        let response = self.client.request(
299            Routes::Tags,
300            RequestType::Get,
301            None,
302            Some(query_string),
303            None,
304        )?;
305
306        if !response.status().is_success() {
307            error!("Failed to get tags. Status: {:?}", response.status());
308            return Err(ClientError::GetTagsError);
309        }
310
311        let body = response.bytes()?;
312
313        let tags_response: TagsResponse = serde_json::from_slice(&body)?;
314
315        Ok(tags_response)
316    }
317}
318
319#[pyclass(name = "ScouterClient")]
320pub struct PyScouterClient {
321    client: ScouterClient,
322}
323#[pymethods]
324impl PyScouterClient {
325    #[new]
326    #[pyo3(signature = (config=None))]
327    pub fn new(config: Option<&Bound<'_, PyAny>>) -> Result<Self, ClientError> {
328        let config = config.map_or(Ok(HttpConfig::default()), |unwrapped| {
329            if unwrapped.is_instance_of::<HttpConfig>() {
330                unwrapped.extract::<HttpConfig>()
331            } else {
332                Err(ClientError::InvalidConfigTypeError.into())
333            }
334        })?;
335
336        let client = ScouterClient::new(Some(config.clone()))?;
337
338        Ok(PyScouterClient { client })
339    }
340
341    /// Insert a profile into the scouter server
342    ///
343    /// # Arguments
344    ///
345    /// * `profile` - A profile object to insert
346    ///
347    /// # Returns
348    ///
349    /// * `Ok(())` if the profile was inserted successfully
350    #[pyo3(signature = (profile, set_active=false, deactivate_others=false))]
351    pub fn register_profile(
352        &self,
353        profile: &Bound<'_, PyAny>,
354        set_active: bool,
355        deactivate_others: bool,
356    ) -> Result<bool, ClientError> {
357        let request = profile
358            .call_method0("create_profile_request")?
359            .extract::<ProfileRequest>()?;
360
361        let profile_response = self.client.insert_profile(&request)?;
362
363        // update config args
364        profile.call_method1(
365            "update_config_args",
366            (
367                Some(profile_response.space),
368                Some(profile_response.name),
369                Some(profile_response.version),
370            ),
371        )?;
372
373        debug!("Profile inserted successfully");
374        if set_active {
375            let name = profile
376                .getattr("config")?
377                .getattr("name")?
378                .extract::<String>()?;
379
380            let space = profile
381                .getattr("config")?
382                .getattr("space")?
383                .extract::<String>()?;
384
385            let version = profile
386                .getattr("config")?
387                .getattr("version")?
388                .extract::<String>()?;
389
390            let drift_type = profile
391                .getattr("config")?
392                .getattr("drift_type")?
393                .extract::<DriftType>()?;
394
395            let request = ProfileStatusRequest {
396                name,
397                space,
398                version,
399                active: true,
400                drift_type: Some(drift_type),
401                deactivate_others,
402            };
403
404            self.client.update_profile_status(&request)?;
405        }
406
407        Ok(true)
408    }
409
410    /// Update the status of a profile
411    ///
412    /// # Arguments
413    /// * `request` - A profile status request object
414    ///
415    /// # Returns
416    /// * `Ok(())` if the profile status was updated successfully
417    pub fn update_profile_status(
418        &self,
419        request: ProfileStatusRequest,
420    ) -> Result<bool, ClientError> {
421        self.client.update_profile_status(&request)
422    }
423
424    /// Get binned drift data from the scouter server
425    ///
426    /// # Arguments
427    ///
428    /// * `drift_request` - A drift request object
429    ///
430    /// # Returns
431    ///
432    /// * A binned drift object
433    pub fn get_binned_drift<'py>(
434        &self,
435        py: Python<'py>,
436        drift_request: DriftRequest,
437    ) -> Result<Bound<'py, PyAny>, ClientError> {
438        match drift_request.drift_type {
439            DriftType::Spc => {
440                PyScouterClient::get_spc_binned_drift(py, &self.client.client, drift_request)
441            }
442            DriftType::Psi => {
443                PyScouterClient::get_psi_binned_drift(py, &self.client.client, drift_request)
444            }
445            DriftType::Custom => {
446                PyScouterClient::get_custom_binned_drift(py, &self.client.client, drift_request)
447            }
448            DriftType::LLM => {
449                PyScouterClient::get_llm_metric_binned_drift(py, &self.client.client, drift_request)
450            }
451        }
452    }
453
454    pub fn get_alerts(&self, request: DriftAlertRequest) -> Result<Vec<Alert>, ClientError> {
455        debug!("Getting alerts for: {:?}", request);
456
457        let alerts = self.client.get_alerts(&request)?;
458
459        Ok(alerts)
460    }
461
462    #[pyo3(signature = (request, path))]
463    pub fn download_profile(
464        &self,
465        request: GetProfileRequest,
466        path: Option<PathBuf>,
467    ) -> Result<String, ClientError> {
468        debug!("Downloading profile: {:?}", request);
469
470        let filename = format!(
471            "{}_{}_{}_{}.json",
472            request.name, request.space, request.version, request.drift_type
473        );
474
475        let profile = self.client.get_drift_profile(request)?;
476
477        PyHelperFuncs::save_to_json(profile, path.clone(), &filename)?;
478
479        Ok(path.map_or(filename, |p| p.to_string_lossy().to_string()))
480    }
481
482    /// Get paginated traces from the scouter server
483    /// # Arguments
484    /// * `filters` - A trace filters object
485    /// # Returns
486    /// * A trace pagination response object
487    pub fn get_paginated_traces(
488        &self,
489        filters: TraceFilters,
490    ) -> Result<TracePaginationResponse, ClientError> {
491        self.client.get_paginated_traces(&filters)
492    }
493
494    /// Refresh the trace summary on the scouter server
495    pub fn refresh_trace_summary(&self) -> Result<bool, ClientError> {
496        self.client.refresh_trace_summary()
497    }
498
499    /// Get trace spans for a given trace ID
500    /// # Arguments
501    /// * `trace_id` - The ID of the trace
502    /// # Returns
503    /// * A trace spans response object
504    pub fn get_trace_spans(&self, trace_id: &str) -> Result<TraceSpansResponse, ClientError> {
505        self.client.get_trace_spans(trace_id)
506    }
507
508    /// Get trace metrics for a given trace metrics request
509    /// # Arguments
510    /// * `trace_id` - The ID of the trace
511    /// # Returns
512    /// * A trace baggage response object
513    pub fn get_trace_baggage(&self, trace_id: &str) -> Result<TraceBaggageResponse, ClientError> {
514        self.client.get_trace_baggage(trace_id)
515    }
516
517    /// Get trace metrics for a given trace metrics request
518    /// # Arguments
519    /// * `request` - A trace metrics request object
520    /// # Returns
521    /// * A trace metrics response object
522    pub fn get_trace_metrics(
523        &self,
524        request: TraceMetricsRequest,
525    ) -> Result<TraceMetricsResponse, ClientError> {
526        self.client.get_trace_metrics(request)
527    }
528
529    /// Get tags for a given entity type and ID
530    /// # Arguments
531    /// * `entity_type` - The type of the entity
532    /// * `entity_id` - The ID of the entity
533    /// # Returns
534    /// * A tags response object
535    pub fn get_tags(
536        &self,
537        entity_type: String,
538        entity_id: String,
539    ) -> Result<TagsResponse, ClientError> {
540        let tag_request = TagsRequest {
541            entity_type,
542            entity_id,
543        };
544        self.client.get_tags(tag_request)
545    }
546}
547
548impl PyScouterClient {
549    fn get_spc_binned_drift<'py>(
550        py: Python<'py>,
551        client: &HttpClient,
552        drift_request: DriftRequest,
553    ) -> Result<Bound<'py, PyAny>, ClientError> {
554        let query_string = serde_qs::to_string(&drift_request)?;
555
556        let response = client.request(
557            Routes::SpcDrift,
558            RequestType::Get,
559            None,
560            Some(query_string),
561            None,
562        )?;
563
564        if response.status().is_client_error() || response.status().is_server_error() {
565            return Err(ClientError::GetDriftDataError);
566        }
567
568        let body = response.bytes()?;
569
570        let results: SpcDriftFeatures = serde_json::from_slice(&body)?;
571
572        Ok(results.into_bound_py_any(py).unwrap())
573    }
574    fn get_psi_binned_drift<'py>(
575        py: Python<'py>,
576        client: &HttpClient,
577        drift_request: DriftRequest,
578    ) -> Result<Bound<'py, PyAny>, ClientError> {
579        let query_string = serde_qs::to_string(&drift_request)?;
580
581        let response = client.request(
582            Routes::PsiDrift,
583            RequestType::Get,
584            None,
585            Some(query_string),
586            None,
587        )?;
588
589        if response.status().is_client_error() || response.status().is_server_error() {
590            // print response text
591            error!(
592                "Failed to get PSI drift data. Status: {:?}",
593                response.status()
594            );
595            error!("Response text: {:?}", response.text());
596            return Err(ClientError::GetDriftDataError);
597        }
598
599        let body = response.bytes()?;
600
601        let results: BinnedPsiFeatureMetrics = serde_json::from_slice(&body)?;
602
603        Ok(results.into_bound_py_any(py).unwrap())
604    }
605
606    fn get_custom_binned_drift<'py>(
607        py: Python<'py>,
608        client: &HttpClient,
609        drift_request: DriftRequest,
610    ) -> Result<Bound<'py, PyAny>, ClientError> {
611        let query_string = serde_qs::to_string(&drift_request)?;
612
613        let response = client.request(
614            Routes::CustomDrift,
615            RequestType::Get,
616            None,
617            Some(query_string),
618            None,
619        )?;
620
621        if response.status().is_client_error() || response.status().is_server_error() {
622            return Err(ClientError::GetDriftDataError);
623        }
624
625        let body = response.bytes()?;
626
627        let results: BinnedMetrics = serde_json::from_slice(&body)?;
628
629        Ok(results.into_bound_py_any(py).unwrap())
630    }
631
632    fn get_llm_metric_binned_drift<'py>(
633        py: Python<'py>,
634        client: &HttpClient,
635        drift_request: DriftRequest,
636    ) -> Result<Bound<'py, PyAny>, ClientError> {
637        let query_string = serde_qs::to_string(&drift_request)?;
638
639        let response = client.request(
640            Routes::LLMDrift,
641            RequestType::Get,
642            None,
643            Some(query_string),
644            None,
645        )?;
646
647        if response.status().is_client_error() || response.status().is_server_error() {
648            return Err(ClientError::GetDriftDataError);
649        }
650
651        let body = response.bytes()?;
652
653        let results: BinnedMetrics = serde_json::from_slice(&body)?;
654
655        Ok(results.into_bound_py_any(py).unwrap())
656    }
657}