use super::completion::{CompletionRequest, CompletionResponse};
use super::config::{CloudBackend, CloudConfig};
use super::error::CloudError;
use crate::http::{with_retry, CircuitBreaker, CircuitConfig, RetryPolicy, RetryResult};
use serde_json::json;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Connect timeout, in milliseconds, handed to the HTTP agent built in `Cloud::with_config`.
const DEFAULT_CONNECT_TIMEOUT_MS: u64 = 10_000;
/// Completion client that routes requests either through the gateway or to a direct provider,
/// depending on `config.backend`.
pub struct Cloud {
/// Configuration this client was built with (backend, gateway URL, default model, timeout, debug).
config: CloudConfig,
/// HTTP agent used for gateway requests; carries the connect and overall timeouts.
agent: ureq::Agent,
/// Circuit breaker checked before gateway calls and passed to the retry helper.
gateway_circuit: Arc<CircuitBreaker>,
/// Retry policy passed to `with_retry` for gateway calls.
retry_policy: RetryPolicy,
}
impl Cloud {
/// Builds a client from `CloudConfig::default()`.
///
/// # Errors
/// Propagates whatever [`Cloud::with_config`] returns for the default configuration.
pub fn new() -> Result<Self, CloudError> {
    let defaults = CloudConfig::default();
    Self::with_config(defaults)
}
/// Builds a client from an explicit configuration.
///
/// The HTTP agent gets a fixed connect timeout (`DEFAULT_CONNECT_TIMEOUT_MS`) and an overall
/// timeout taken from `config.timeout_ms`. The gateway circuit breaker starts from
/// `CircuitConfig::default()` and retries use `RetryPolicy::conservative()`.
///
/// # Errors
/// The current implementation always returns `Ok`; the `Result` is part of the public signature.
pub fn with_config(config: CloudConfig) -> Result<Self, CloudError> {
    let connect_timeout = Duration::from_millis(DEFAULT_CONNECT_TIMEOUT_MS);
    let request_timeout = Duration::from_millis(config.timeout_ms as u64);
    let agent = ureq::AgentBuilder::new()
        .timeout_connect(connect_timeout)
        .timeout(request_timeout)
        .build();
    Ok(Self {
        config,
        agent,
        gateway_circuit: Arc::new(CircuitBreaker::new(CircuitConfig::default())),
        retry_policy: RetryPolicy::conservative(),
    })
}
/// Builds a client from `CloudConfig::gateway()`.
///
/// # Errors
/// Propagates whatever [`Cloud::with_config`] returns.
pub fn gateway() -> Result<Self, CloudError> {
    let config = CloudConfig::gateway();
    Self::with_config(config)
}
/// Builds a client from `CloudConfig::direct(provider)`.
///
/// # Errors
/// Propagates whatever [`Cloud::with_config`] returns.
pub fn direct(provider: &str) -> Result<Self, CloudError> {
    let config = CloudConfig::direct(provider);
    Self::with_config(config)
}
/// Reports whether the gateway circuit breaker is currently open.
pub fn is_circuit_open(&self) -> bool {
    let breaker = &self.gateway_circuit;
    breaker.is_open()
}
/// Resets the gateway circuit breaker.
pub fn reset_circuit(&self) {
    let breaker = &self.gateway_circuit;
    breaker.reset();
}
/// Runs a completion on the backend selected by `config.backend` and stamps the elapsed
/// wall-clock time into `latency_ms`.
///
/// # Errors
/// Propagates the error from the gateway path or the direct path unchanged; no latency is
/// recorded in that case.
pub fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, CloudError> {
    let start = Instant::now();
    let mut response = match self.config.backend {
        CloudBackend::Gateway => self.complete_via_gateway(request)?,
        CloudBackend::Direct => self.call_direct(request)?,
    };
    // `as_millis` yields a `u128`; clamp before narrowing so the value saturates at
    // `u32::MAX` instead of silently wrapping.
    let elapsed_ms = start.elapsed().as_millis().min(u128::from(u32::MAX)) as u32;
    response.latency_ms = Some(elapsed_ms);
    Ok(response)
}
/// Runs a gateway completion behind the circuit breaker and retry policy.
///
/// # Errors
/// Returns `CloudError::CircuitOpen` without issuing any request when the breaker refuses
/// execution; otherwise returns whatever the retry helper's final result converts to.
fn complete_via_gateway(
    &self,
    request: CompletionRequest,
) -> Result<CompletionResponse, CloudError> {
    if !self.gateway_circuit.can_execute() {
        return Err(CloudError::CircuitOpen(
            "Gateway circuit breaker is open due to recent failures. Try again later.".into(),
        ));
    }
    // `call_gateway` consumes its request and the closure may be invoked again by
    // `with_retry`, so hand each attempt its own clone of the owned request.
    let result: RetryResult<CompletionResponse, CloudError> =
        with_retry(&self.retry_policy, Some(&self.gateway_circuit), || {
            self.call_gateway(request.clone())
        });
    result.into_result()
}
/// Sends `prompt` as a single request and returns only the response text.
///
/// # Errors
/// Propagates any error from [`Cloud::complete`].
pub fn prompt(&self, prompt: &str) -> Result<String, CloudError> {
    self.complete(CompletionRequest::new(prompt))
        .map(|response| response.text)
}
/// Sends `user_message` together with a `system` instruction and returns only the response text.
///
/// # Errors
/// Propagates any error from [`Cloud::complete`].
pub fn chat(&self, system: &str, user_message: &str) -> Result<String, CloudError> {
    let request = CompletionRequest::new(user_message).with_system(system);
    self.complete(request).map(|response| response.text)
}
fn call_gateway(&self, request: CompletionRequest) -> Result<CompletionResponse, CloudError> {
let api_key = self.config.resolve_api_key();
let messages = request.to_messages();
let openai_messages: Vec<serde_json::Value> = messages
.iter()
.map(|m| {
json!({
"role": match m.role {
super::completion::Role::System => "system",
super::completion::Role::User => "user",
super::completion::Role::Assistant => "assistant",
},
"content": &m.content
})
})
.collect();
let model = request
.model
.clone()
.or_else(|| self.config.default_model.clone())
.unwrap_or_else(|| "gpt-4o-mini".to_string());
let mut body = json!({
"model": model,
"messages": openai_messages,
});
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = json!(max_tokens);
}
if let Some(temperature) = request.temperature {
body["temperature"] = json!(temperature);
}
if let Some(top_p) = request.top_p {
body["top_p"] = json!(top_p);
}
if let Some(ref stop) = request.stop {
body["stop"] = json!(stop);
}
if request.stream {
body["stream"] = json!(true);
}
let url = format!("{}/chat/completions", self.config.gateway_url);
if self.config.debug {
eprintln!("[Cloud] Gateway request to: {}", url);
eprintln!(
"[Cloud] Body: {}",
serde_json::to_string_pretty(&body).unwrap_or_default()
);
}
let mut req = self
.agent
.post(&url)
.set("Content-Type", "application/json");
if let Some(ref key) = api_key {
req = req.set("Authorization", &format!("Bearer {}", key));
}
let response = req.send_json(&body);
match response {
Ok(resp) => {
let json_resp: serde_json::Value = resp
.into_json()
.map_err(|e| CloudError::ParseError(e.to_string()))?;
if self.config.debug {
eprintln!(
"[Cloud] Response: {}",
serde_json::to_string_pretty(&json_resp).unwrap_or_default()
);
}
let text = json_resp["choices"][0]["message"]["content"]
.as_str()
.unwrap_or("")
.to_string();
let model = json_resp["model"].as_str().unwrap_or("unknown").to_string();
let finish_reason = json_resp["choices"][0]["finish_reason"]
.as_str()
.map(|s| s.to_string());
let usage = json_resp.get("usage").map(parse_gateway_usage);
let id = json_resp["id"].as_str().map(|s| s.to_string());
Ok(CompletionResponse {
text,
model,
finish_reason,
usage,
id,
latency_ms: None,
backend: Some("gateway".to_string()),
})
}
Err(ureq::Error::Status(status, resp)) => {
let retry_after_secs = resp
.header("Retry-After")
.and_then(|v| v.parse::<u64>().ok())
.unwrap_or(60);
let error_body: Result<serde_json::Value, _> = resp.into_json();
let message = error_body
.ok()
.and_then(|v| v["error"]["message"].as_str().map(|s| s.to_string()))
.unwrap_or_else(|| "Unknown error".into());
match status {
429 => Err(CloudError::RateLimited { retry_after_secs }),
502..=504 => Err(CloudError::GatewayError(format!(
"Gateway returned {}: {}",
status, message
))),
_ => Err(CloudError::ApiError { status, message }),
}
}
Err(ureq::Error::Transport(transport)) => {
let msg = transport.to_string();
if msg.contains("timed out") || msg.contains("timeout") {
return Err(CloudError::Timeout {
timeout_ms: self.config.timeout_ms,
});
}
Err(CloudError::NetworkError(msg))
}
}
}
/// Runs a completion directly against the provider named in `config.direct_provider`.
///
/// The response's `backend` is tagged as `direct:<provider>`.
///
/// # Errors
/// Returns `CloudError::ConfigError` when no provider is configured or its name does not
/// parse; client construction and completion errors are propagated via `?`.
fn call_direct(&self, request: CompletionRequest) -> Result<CompletionResponse, CloudError> {
    let provider = match self.config.direct_provider.as_ref() {
        Some(name) => name,
        None => {
            return Err(CloudError::ConfigError(
                "Direct provider not configured".to_string(),
            ))
        }
    };
    let llm_provider: crate::pipeline::IntegrationProvider =
        provider.parse().map_err(CloudError::ConfigError)?;
    let client = crate::cloud_llm::LlmClient::new(llm_provider)?;
    let llm_request: crate::cloud_llm::LlmRequest = request.into();
    let mut completion_response: CompletionResponse = client.complete(llm_request)?.into();
    completion_response.backend = Some(format!("direct:{}", provider));
    Ok(completion_response)
}
pub fn config(&self) -> &CloudConfig {
&self.config
}
}
/// `Default` delegates to [`Cloud::new`].
impl Default for Cloud {
    /// # Panics
    /// Panics with "Failed to create default Cloud client" if `Cloud::new` returns an error.
    fn default() -> Self {
        let built = Self::new();
        built.expect("Failed to create default Cloud client")
    }
}
/// Extracts token usage from a gateway `usage` JSON object.
///
/// `prompt_tokens`, `completion_tokens` and `total_tokens` default to 0 when absent or not
/// an unsigned integer. `cache_read_input_tokens` is `Some` only when at least one of
/// `prompt_cache_hit_tokens` / `prompt_cache_miss_tokens` is present, taking the hit count
/// (0 if only the miss field exists). `cache_creation_input_tokens` is always `None`.
///
/// Counters larger than `u32::MAX` saturate instead of wrapping.
pub(crate) fn parse_gateway_usage(u: &serde_json::Value) -> super::completion::Usage {
    // Reads one unsigned counter; clamp before narrowing so the cast cannot wrap.
    let count = |key: &str| -> Option<u32> {
        u.get(key)
            .and_then(|v| v.as_u64())
            .map(|n| n.min(u64::from(u32::MAX)) as u32)
    };
    let has_cache_fields =
        u.get("prompt_cache_hit_tokens").is_some() || u.get("prompt_cache_miss_tokens").is_some();
    // A response with only the miss field still reports `Some(0)` hits, which is distinct
    // from a response with no cache fields at all (`None`).
    let cache_read_input_tokens =
        has_cache_fields.then(|| count("prompt_cache_hit_tokens").unwrap_or(0));
    super::completion::Usage {
        prompt_tokens: count("prompt_tokens").unwrap_or(0),
        completion_tokens: count("completion_tokens").unwrap_or(0),
        total_tokens: count("total_tokens").unwrap_or(0),
        cache_read_input_tokens,
        cache_creation_input_tokens: None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default client is constructible and targets the gateway backend.
    #[test]
    fn test_cloud_new() {
        let cloud = Cloud::new();
        assert!(cloud.is_ok());
        let cloud = cloud.unwrap();
        assert_eq!(cloud.config().backend, CloudBackend::Gateway);
    }

    /// Builder-style configuration values are preserved by the client.
    #[test]
    fn test_cloud_with_config() {
        let config = CloudConfig::gateway()
            .with_default_model("gpt-4o-mini")
            .with_timeout(60000);
        let cloud = Cloud::with_config(config).unwrap();
        assert_eq!(
            cloud.config().default_model,
            Some("gpt-4o-mini".to_string())
        );
        assert_eq!(cloud.config().timeout_ms, 60000);
    }

    /// A direct-provider client can be constructed for "openai".
    #[test]
    fn test_cloud_direct() {
        // The environment is process-global: remember any pre-existing key and put it back
        // before asserting, so a real key is not deleted and a failing assert cannot leak
        // the override into other tests.
        let previous = std::env::var_os("OPENAI_API_KEY");
        std::env::set_var("OPENAI_API_KEY", "test");
        let cloud = Cloud::direct("openai");
        match previous {
            Some(value) => std::env::set_var("OPENAI_API_KEY", value),
            None => std::env::remove_var("OPENAI_API_KEY"),
        }
        assert!(cloud.is_ok());
    }

    /// A freshly built client starts with a closed circuit.
    #[test]
    fn test_cloud_circuit_breaker_initial_state() {
        let cloud = Cloud::new().unwrap();
        assert!(!cloud.is_circuit_open());
    }

    /// DeepSeek-style cache hit/miss fields map onto `cache_read_input_tokens`.
    #[test]
    fn gateway_usage_maps_deepseek_cache_fields() {
        let blob = serde_json::json!({
            "prompt_tokens": 1000,
            "completion_tokens": 120,
            "total_tokens": 1120,
            "prompt_cache_hit_tokens": 800,
            "prompt_cache_miss_tokens": 200,
        });
        let usage = parse_gateway_usage(&blob);
        assert_eq!(usage.prompt_tokens, 1000);
        assert_eq!(usage.completion_tokens, 120);
        assert_eq!(usage.cache_read_input_tokens, Some(800));
        assert_eq!(usage.cache_creation_input_tokens, None);
        // Uncached input is whatever remains after subtracting both cache counters.
        let derived_uncached = usage
            .prompt_tokens
            .saturating_sub(usage.cache_read_input_tokens.unwrap_or(0))
            .saturating_sub(usage.cache_creation_input_tokens.unwrap_or(0));
        assert_eq!(derived_uncached, 200);
    }

    /// Only the miss field present still yields an explicit zero-hit count.
    #[test]
    fn gateway_usage_cold_cache_only_miss_field() {
        let blob = serde_json::json!({
            "prompt_tokens": 500,
            "completion_tokens": 50,
            "total_tokens": 550,
            "prompt_cache_miss_tokens": 500,
        });
        let usage = parse_gateway_usage(&blob);
        assert_eq!(usage.cache_read_input_tokens, Some(0));
        assert_eq!(usage.cache_creation_input_tokens, None);
    }

    /// Without any cache fields, both cache counters stay `None`.
    #[test]
    fn gateway_usage_no_cache_fields_stays_none() {
        let blob = serde_json::json!({
            "prompt_tokens": 300,
            "completion_tokens": 30,
            "total_tokens": 330,
        });
        let usage = parse_gateway_usage(&blob);
        assert_eq!(usage.cache_read_input_tokens, None);
        assert_eq!(usage.cache_creation_input_tokens, None);
    }

    /// Recorded failures open the circuit and `reset_circuit` closes it again.
    #[test]
    fn test_cloud_circuit_breaker_reset() {
        let cloud = Cloud::new().unwrap();
        // Presumably matches the default failure threshold in `CircuitConfig` — verify there.
        for _ in 0..3 {
            cloud.gateway_circuit.record_failure();
        }
        assert!(cloud.is_circuit_open());
        cloud.reset_circuit();
        assert!(!cloud.is_circuit_open());
    }

    /// `CircuitOpen` has the expected variant shape and display text.
    #[test]
    fn test_cloud_error_circuit_open() {
        let err = CloudError::CircuitOpen("test".to_string());
        assert!(matches!(err, CloudError::CircuitOpen(_)));
        assert_eq!(err.to_string(), "Circuit breaker open: test");
    }

    /// An open circuit must not be retried.
    #[test]
    fn test_circuit_open_not_retryable() {
        use crate::http::RetryableError;
        let err = CloudError::CircuitOpen("test".to_string());
        assert!(!err.is_retryable());
    }
}