use std::time::Duration;
use super::error::BatutaError;
use super::pricing::{FallbackPricing, GpuPricing};
use super::queue::QueueState;
/// Client for GPU pricing and queue-depth lookups.
///
/// In this module every lookup is answered from a local [`FallbackPricing`]
/// table; `base_url` is stored but no method here sends requests to it yet.
#[derive(Debug, Clone)]
pub struct BatutaClient {
    /// Service endpoint; `None` for a client built with `new()`.
    base_url: Option<String>,
    /// Local pricing table used to answer all lookups.
    fallback: FallbackPricing,
    /// Configured request timeout (set by the constructors and `with_timeout`;
    /// not read by any method in this file yet).
    timeout: Duration,
    /// Whether the service is treated as reachable; combined with `base_url`
    /// in `is_connected`.
    service_available: bool,
}
impl Default for BatutaClient {
fn default() -> Self {
Self::new()
}
}
impl BatutaClient {
    /// Creates an offline client with no service URL.
    ///
    /// Lookups are answered from [`FallbackPricing::new`], the timeout is
    /// 5 seconds, and [`is_connected`](Self::is_connected) returns `false`.
    pub fn new() -> Self {
        Self {
            base_url: None,
            fallback: FallbackPricing::new(),
            timeout: Duration::from_secs(5),
            service_available: false,
        }
    }

    /// Creates a client configured with a service URL.
    ///
    /// All other settings match [`BatutaClient::new`]. The service is marked
    /// as available immediately; no request is made to `url` here.
    ///
    /// NOTE(review): remote lookups are not implemented in this module, so
    /// queries still use the fallback table even though
    /// [`is_connected`](Self::is_connected) reports `true`.
    pub fn with_url(url: impl Into<String>) -> Self {
        Self {
            base_url: Some(url.into()),
            service_available: true,
            // Reuse the offline defaults instead of duplicating them.
            ..Self::new()
        }
    }

    /// Returns the client with its request timeout replaced by `timeout`.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Returns the client with its fallback pricing table replaced by `fallback`.
    #[must_use]
    pub fn with_fallback(mut self, fallback: FallbackPricing) -> Self {
        self.fallback = fallback;
        self
    }

    /// Returns `true` when a service URL is configured and the service is
    /// marked as available.
    pub fn is_connected(&self) -> bool {
        self.base_url.is_some() && self.service_available
    }

    /// Looks up the pricing entry for `gpu_type`.
    ///
    /// The entry is cloned out of the fallback table; the configured service
    /// URL is not consulted yet.
    ///
    /// # Errors
    ///
    /// Returns [`BatutaError::UnknownGpuType`] if the fallback table has no
    /// entry for `gpu_type`.
    pub fn get_hourly_rate(&self, gpu_type: &str) -> Result<GpuPricing, BatutaError> {
        // TODO: query `self.base_url` before falling back once the remote API exists.
        self.fallback
            .get_rate(gpu_type)
            .cloned()
            .ok_or_else(|| BatutaError::UnknownGpuType(gpu_type.to_string()))
    }

    /// Returns the queue state for `gpu_type`.
    ///
    /// `gpu_type` is validated against the fallback pricing table. The state
    /// returned is always the fixed placeholder `QueueState::new(0, 4, 4)`.
    ///
    /// # Errors
    ///
    /// Returns [`BatutaError::UnknownGpuType`] if `gpu_type` is not in the
    /// fallback table.
    pub fn get_queue_depth(&self, gpu_type: &str) -> Result<QueueState, BatutaError> {
        if self.fallback.get_rate(gpu_type).is_none() {
            return Err(BatutaError::UnknownGpuType(gpu_type.to_string()));
        }
        // TODO: query `self.base_url` for live queue data once the remote API exists.
        // NOTE(review): argument order of `QueueState::new` is defined elsewhere —
        // presumably (queued, available, total); confirm.
        Ok(QueueState::new(0, 4, 4))
    }

    /// Returns the pricing entry and queue state for `gpu_type` together.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`get_hourly_rate`](Self::get_hourly_rate)
    /// or [`get_queue_depth`](Self::get_queue_depth).
    pub fn get_status(&self, gpu_type: &str) -> Result<(GpuPricing, QueueState), BatutaError> {
        let pricing = self.get_hourly_rate(gpu_type)?;
        let queue = self.get_queue_depth(gpu_type)?;
        Ok((pricing, queue))
    }

    /// Estimates the cost of using `gpu_type` for `hours` hours, computed as
    /// `hourly_rate * hours`. `hours` is not validated.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`get_hourly_rate`](Self::get_hourly_rate).
    pub fn estimate_cost(&self, gpu_type: &str, hours: f64) -> Result<f64, BatutaError> {
        let pricing = self.get_hourly_rate(gpu_type)?;
        Ok(pricing.hourly_rate * hours)
    }

    /// Returns the cheapest fallback entry with at least `min_memory_gb` of
    /// memory, or `None` if no entry qualifies.
    ///
    /// Rates are compared with `total_cmp`, which is a total order. A
    /// positive-NaN rate therefore sorts after every real price instead of
    /// comparing "equal" and being picked by position. Among equal rates,
    /// the first entry in table order wins.
    pub fn cheapest_gpu(&self, min_memory_gb: u32) -> Option<&GpuPricing> {
        self.fallback
            .all_pricing()
            .iter()
            .filter(|p| p.memory_gb >= min_memory_gb)
            .min_by(|a, b| a.hourly_rate.total_cmp(&b.hourly_rate))
    }
}