1#![allow(clippy::useless_conversion)]
2use crate::error::ClientError;
3use pyo3::{prelude::*, IntoPyObjectExt};
4use scouter_settings::http::HttpConfig;
5use scouter_types::contracts::{
6 DriftAlertRequest, DriftRequest, GetProfileRequest, ProfileRequest, ProfileStatusRequest,
7};
8use scouter_types::http::{RequestType, Routes};
9use scouter_types::sql::TraceFilters;
10use scouter_types::{
11 RegisteredProfileResponse, TagsRequest, TagsResponse, TraceBaggageResponse,
12 TraceMetricsRequest, TraceMetricsResponse, TracePaginationResponse, TraceRequest,
13 TraceSpansResponse,
14};
15
16use crate::http::HttpClient;
17use scouter_types::{
18 alert::Alert, psi::BinnedPsiFeatureMetrics, spc::SpcDriftFeatures, BinnedMetrics, DriftProfile,
19 DriftType, PyHelperFuncs,
20};
21use std::path::PathBuf;
22use tracing::{debug, error};
23
/// Chunk size, in bytes, for downloads: 5 MiB.
pub const DOWNLOAD_CHUNK_SIZE: usize = 5 * 1024 * 1024;
25
26#[derive(Debug, Clone)]
27pub struct ScouterClient {
28 client: HttpClient,
29}
30
31impl ScouterClient {
32 pub fn new(config: Option<HttpConfig>) -> Result<Self, ClientError> {
33 let client = HttpClient::new(config.unwrap_or_default())?;
34
35 Ok(ScouterClient { client })
36 }
37
38 pub fn insert_profile(
40 &self,
41 request: &ProfileRequest,
42 ) -> Result<RegisteredProfileResponse, ClientError> {
43 let response = self.client.request(
44 Routes::Profile,
45 RequestType::Post,
46 Some(serde_json::to_value(request).unwrap()),
47 None,
48 None,
49 )?;
50
51 if response.status().is_success() {
52 let body = response.bytes()?;
53 let profile_response: RegisteredProfileResponse = serde_json::from_slice(&body)?;
54
55 debug!("Profile inserted successfully: {:?}", profile_response);
56 Ok(profile_response)
57 } else {
58 Err(ClientError::InsertProfileError)
59 }
60 }
61
62 pub fn update_profile_status(
63 &self,
64 request: &ProfileStatusRequest,
65 ) -> Result<bool, ClientError> {
66 let response = self.client.request(
67 Routes::ProfileStatus,
68 RequestType::Put,
69 Some(serde_json::to_value(request).unwrap()),
70 None,
71 None,
72 )?;
73
74 if response.status().is_success() {
75 Ok(true)
76 } else {
77 Err(ClientError::UpdateProfileError)
78 }
79 }
80
81 pub fn get_alerts(&self, request: &DriftAlertRequest) -> Result<Vec<Alert>, ClientError> {
82 debug!("Getting alerts for: {:?}", request);
83
84 let query_string = serde_qs::to_string(request)?;
85
86 let response = self.client.request(
87 Routes::Alerts,
88 RequestType::Get,
89 None,
90 Some(query_string),
91 None,
92 )?;
93
94 if !response.status().is_success() {
96 return Err(ClientError::GetDriftAlertError);
97 }
98
99 let body: serde_json::Value = response.json()?;
101
102 let alerts = body
104 .get("alerts")
105 .map(|alerts| {
106 serde_json::from_value::<Vec<Alert>>(alerts.clone()).inspect_err(|e| {
107 error!(
108 "Failed to parse drift alerts {:?}. Error: {:?}",
109 &request, e
110 );
111 })
112 })
113 .unwrap_or_else(|| {
114 error!("No alerts found in response");
115 Ok(Vec::new())
116 })?;
117
118 Ok(alerts)
119 }
120
121 pub fn get_drift_profile(
122 &self,
123 request: GetProfileRequest,
124 ) -> Result<DriftProfile, ClientError> {
125 let query_string = serde_qs::to_string(&request)?;
126
127 let response = self.client.request(
128 Routes::Profile,
129 RequestType::Get,
130 None,
131 Some(query_string),
132 None,
133 )?;
134
135 if !response.status().is_success() {
137 error!("Failed to get profile. Status: {:?}", response.status());
138 return Err(ClientError::GetDriftProfileError);
139 }
140
141 let body = response.bytes()?;
143
144 let profile: DriftProfile = serde_json::from_slice(&body)?;
146
147 Ok(profile)
148 }
149
150 pub fn check_service_health(&self) -> Result<bool, ClientError> {
152 let response = self
153 .client
154 .request(Routes::Healthcheck, RequestType::Get, None, None, None)
155 .inspect_err(|e| {
156 error!("Failed to check scouter health {}", e);
157 })?;
158
159 if response.status() == 200 {
160 Ok(true)
161 } else {
162 Ok(false)
163 }
164 }
165
166 fn get_paginated_traces(
167 &self,
168 request: &TraceFilters,
169 ) -> Result<TracePaginationResponse, ClientError> {
170 let response = self.client.request(
171 Routes::PaginatedTraces,
172 RequestType::Post,
173 Some(serde_json::to_value(request).unwrap()),
174 None,
175 None,
176 )?;
177
178 if !response.status().is_success() {
179 let status_code = response.status();
180 let err_msg = response.text().unwrap_or_default();
181 error!(
182 "Failed to get paginated traces. Status: {:?}, Error: {}",
183 status_code, err_msg
184 );
185 return Err(ClientError::GetPaginatedTracesError);
186 }
187
188 let body = response.bytes()?;
190
191 let response: TracePaginationResponse = serde_json::from_slice(&body)?;
193 Ok(response)
194 }
195
196 fn refresh_trace_summary(&self) -> Result<bool, ClientError> {
197 let response = self.client.request(
198 Routes::RefreshTraceSummary,
199 RequestType::Get,
200 None,
201 None,
202 None,
203 )?;
204 if !response.status().is_success() {
205 error!(
206 "Failed to refresh trace summary. Status: {:?}",
207 response.status()
208 );
209 return Err(ClientError::RefreshTraceSummaryError);
210 }
211
212 Ok(true)
213 }
214
215 fn get_trace_spans(&self, trace_id: &str) -> Result<TraceSpansResponse, ClientError> {
216 let trace_request = TraceRequest {
217 trace_id: trace_id.to_string(),
218 };
219
220 let query_string = serde_qs::to_string(&trace_request)?;
221
222 let response = self.client.request(
223 Routes::TraceSpans,
224 RequestType::Get,
225 None,
226 Some(query_string),
227 None,
228 )?;
229 if !response.status().is_success() {
230 error!("Failed to get trace spans. Status: {:?}", response.status());
231 return Err(ClientError::GetTraceSpansError);
232 }
233
234 let body = response.bytes()?;
236 let response: TraceSpansResponse = serde_json::from_slice(&body)?;
238 Ok(response)
239 }
240
241 fn get_trace_metrics(
242 &self,
243 request: TraceMetricsRequest,
244 ) -> Result<TraceMetricsResponse, ClientError> {
245 let query_string = serde_qs::to_string(&request)?;
246 let response = self.client.request(
247 Routes::TraceMetrics,
248 RequestType::Get,
249 None,
250 Some(query_string),
251 None,
252 )?;
253 if !response.status().is_success() {
254 error!(
255 "Failed to get trace metrics. Status: {:?}",
256 response.status()
257 );
258 return Err(ClientError::GetTraceMetricsError);
259 }
260
261 let body = response.bytes()?;
263 let response: TraceMetricsResponse = serde_json::from_slice(&body)?;
265 Ok(response)
266 }
267
268 fn get_trace_baggage(&self, trace_id: &str) -> Result<TraceBaggageResponse, ClientError> {
269 let trace_request = TraceRequest {
270 trace_id: trace_id.to_string(),
271 };
272 let query_string = serde_qs::to_string(&trace_request)?;
273 let response = self.client.request(
274 Routes::TraceBaggage,
275 RequestType::Get,
276 None,
277 Some(query_string),
278 None,
279 )?;
280 if !response.status().is_success() {
281 error!(
282 "Failed to get trace baggage. Status: {:?}",
283 response.status()
284 );
285 return Err(ClientError::GetTraceBaggageError);
286 }
287
288 let body = response.bytes()?;
290 let response: TraceBaggageResponse = serde_json::from_slice(&body)?;
292 Ok(response)
293 }
294
295 fn get_tags(&self, tag_request: TagsRequest) -> Result<TagsResponse, ClientError> {
296 let query_string = serde_qs::to_string(&tag_request)?;
297
298 let response = self.client.request(
299 Routes::Tags,
300 RequestType::Get,
301 None,
302 Some(query_string),
303 None,
304 )?;
305
306 if !response.status().is_success() {
307 error!("Failed to get tags. Status: {:?}", response.status());
308 return Err(ClientError::GetTagsError);
309 }
310
311 let body = response.bytes()?;
312
313 let tags_response: TagsResponse = serde_json::from_slice(&body)?;
314
315 Ok(tags_response)
316 }
317}
318
319#[pyclass(name = "ScouterClient")]
320pub struct PyScouterClient {
321 client: ScouterClient,
322}
323#[pymethods]
324impl PyScouterClient {
325 #[new]
326 #[pyo3(signature = (config=None))]
327 pub fn new(config: Option<&Bound<'_, PyAny>>) -> Result<Self, ClientError> {
328 let config = config.map_or(Ok(HttpConfig::default()), |unwrapped| {
329 if unwrapped.is_instance_of::<HttpConfig>() {
330 unwrapped.extract::<HttpConfig>()
331 } else {
332 Err(ClientError::InvalidConfigTypeError.into())
333 }
334 })?;
335
336 let client = ScouterClient::new(Some(config.clone()))?;
337
338 Ok(PyScouterClient { client })
339 }
340
341 #[pyo3(signature = (profile, set_active=false, deactivate_others=false))]
351 pub fn register_profile(
352 &self,
353 profile: &Bound<'_, PyAny>,
354 set_active: bool,
355 deactivate_others: bool,
356 ) -> Result<bool, ClientError> {
357 let request = profile
358 .call_method0("create_profile_request")?
359 .extract::<ProfileRequest>()?;
360
361 let profile_response = self.client.insert_profile(&request)?;
362
363 profile.call_method1(
365 "update_config_args",
366 (
367 Some(profile_response.space),
368 Some(profile_response.name),
369 Some(profile_response.version),
370 ),
371 )?;
372
373 debug!("Profile inserted successfully");
374 if set_active {
375 let name = profile
376 .getattr("config")?
377 .getattr("name")?
378 .extract::<String>()?;
379
380 let space = profile
381 .getattr("config")?
382 .getattr("space")?
383 .extract::<String>()?;
384
385 let version = profile
386 .getattr("config")?
387 .getattr("version")?
388 .extract::<String>()?;
389
390 let drift_type = profile
391 .getattr("config")?
392 .getattr("drift_type")?
393 .extract::<DriftType>()?;
394
395 let request = ProfileStatusRequest {
396 name,
397 space,
398 version,
399 active: true,
400 drift_type: Some(drift_type),
401 deactivate_others,
402 };
403
404 self.client.update_profile_status(&request)?;
405 }
406
407 Ok(true)
408 }
409
410 pub fn update_profile_status(
418 &self,
419 request: ProfileStatusRequest,
420 ) -> Result<bool, ClientError> {
421 self.client.update_profile_status(&request)
422 }
423
424 pub fn get_binned_drift<'py>(
434 &self,
435 py: Python<'py>,
436 drift_request: DriftRequest,
437 ) -> Result<Bound<'py, PyAny>, ClientError> {
438 match drift_request.drift_type {
439 DriftType::Spc => {
440 PyScouterClient::get_spc_binned_drift(py, &self.client.client, drift_request)
441 }
442 DriftType::Psi => {
443 PyScouterClient::get_psi_binned_drift(py, &self.client.client, drift_request)
444 }
445 DriftType::Custom => {
446 PyScouterClient::get_custom_binned_drift(py, &self.client.client, drift_request)
447 }
448 DriftType::LLM => {
449 PyScouterClient::get_llm_metric_binned_drift(py, &self.client.client, drift_request)
450 }
451 }
452 }
453
454 pub fn get_alerts(&self, request: DriftAlertRequest) -> Result<Vec<Alert>, ClientError> {
455 debug!("Getting alerts for: {:?}", request);
456
457 let alerts = self.client.get_alerts(&request)?;
458
459 Ok(alerts)
460 }
461
462 #[pyo3(signature = (request, path))]
463 pub fn download_profile(
464 &self,
465 request: GetProfileRequest,
466 path: Option<PathBuf>,
467 ) -> Result<String, ClientError> {
468 debug!("Downloading profile: {:?}", request);
469
470 let filename = format!(
471 "{}_{}_{}_{}.json",
472 request.name, request.space, request.version, request.drift_type
473 );
474
475 let profile = self.client.get_drift_profile(request)?;
476
477 PyHelperFuncs::save_to_json(profile, path.clone(), &filename)?;
478
479 Ok(path.map_or(filename, |p| p.to_string_lossy().to_string()))
480 }
481
482 pub fn get_paginated_traces(
488 &self,
489 filters: TraceFilters,
490 ) -> Result<TracePaginationResponse, ClientError> {
491 self.client.get_paginated_traces(&filters)
492 }
493
494 pub fn refresh_trace_summary(&self) -> Result<bool, ClientError> {
496 self.client.refresh_trace_summary()
497 }
498
499 pub fn get_trace_spans(&self, trace_id: &str) -> Result<TraceSpansResponse, ClientError> {
505 self.client.get_trace_spans(trace_id)
506 }
507
508 pub fn get_trace_baggage(&self, trace_id: &str) -> Result<TraceBaggageResponse, ClientError> {
514 self.client.get_trace_baggage(trace_id)
515 }
516
517 pub fn get_trace_metrics(
523 &self,
524 request: TraceMetricsRequest,
525 ) -> Result<TraceMetricsResponse, ClientError> {
526 self.client.get_trace_metrics(request)
527 }
528
529 pub fn get_tags(
536 &self,
537 entity_type: String,
538 entity_id: String,
539 ) -> Result<TagsResponse, ClientError> {
540 let tag_request = TagsRequest {
541 entity_type,
542 entity_id,
543 };
544 self.client.get_tags(tag_request)
545 }
546}
547
548impl PyScouterClient {
549 fn get_spc_binned_drift<'py>(
550 py: Python<'py>,
551 client: &HttpClient,
552 drift_request: DriftRequest,
553 ) -> Result<Bound<'py, PyAny>, ClientError> {
554 let query_string = serde_qs::to_string(&drift_request)?;
555
556 let response = client.request(
557 Routes::SpcDrift,
558 RequestType::Get,
559 None,
560 Some(query_string),
561 None,
562 )?;
563
564 if response.status().is_client_error() || response.status().is_server_error() {
565 return Err(ClientError::GetDriftDataError);
566 }
567
568 let body = response.bytes()?;
569
570 let results: SpcDriftFeatures = serde_json::from_slice(&body)?;
571
572 Ok(results.into_bound_py_any(py).unwrap())
573 }
574 fn get_psi_binned_drift<'py>(
575 py: Python<'py>,
576 client: &HttpClient,
577 drift_request: DriftRequest,
578 ) -> Result<Bound<'py, PyAny>, ClientError> {
579 let query_string = serde_qs::to_string(&drift_request)?;
580
581 let response = client.request(
582 Routes::PsiDrift,
583 RequestType::Get,
584 None,
585 Some(query_string),
586 None,
587 )?;
588
589 if response.status().is_client_error() || response.status().is_server_error() {
590 error!(
592 "Failed to get PSI drift data. Status: {:?}",
593 response.status()
594 );
595 error!("Response text: {:?}", response.text());
596 return Err(ClientError::GetDriftDataError);
597 }
598
599 let body = response.bytes()?;
600
601 let results: BinnedPsiFeatureMetrics = serde_json::from_slice(&body)?;
602
603 Ok(results.into_bound_py_any(py).unwrap())
604 }
605
606 fn get_custom_binned_drift<'py>(
607 py: Python<'py>,
608 client: &HttpClient,
609 drift_request: DriftRequest,
610 ) -> Result<Bound<'py, PyAny>, ClientError> {
611 let query_string = serde_qs::to_string(&drift_request)?;
612
613 let response = client.request(
614 Routes::CustomDrift,
615 RequestType::Get,
616 None,
617 Some(query_string),
618 None,
619 )?;
620
621 if response.status().is_client_error() || response.status().is_server_error() {
622 return Err(ClientError::GetDriftDataError);
623 }
624
625 let body = response.bytes()?;
626
627 let results: BinnedMetrics = serde_json::from_slice(&body)?;
628
629 Ok(results.into_bound_py_any(py).unwrap())
630 }
631
632 fn get_llm_metric_binned_drift<'py>(
633 py: Python<'py>,
634 client: &HttpClient,
635 drift_request: DriftRequest,
636 ) -> Result<Bound<'py, PyAny>, ClientError> {
637 let query_string = serde_qs::to_string(&drift_request)?;
638
639 let response = client.request(
640 Routes::LLMDrift,
641 RequestType::Get,
642 None,
643 Some(query_string),
644 None,
645 )?;
646
647 if response.status().is_client_error() || response.status().is_server_error() {
648 return Err(ClientError::GetDriftDataError);
649 }
650
651 let body = response.bytes()?;
652
653 let results: BinnedMetrics = serde_json::from_slice(&body)?;
654
655 Ok(results.into_bound_py_any(py).unwrap())
656 }
657}