use super::{BaseProvider, ModelPricing, Provider, ProviderError, ProviderType};
use crate::config::ProviderConfig;
use crate::core::models::{RequestContext, openai::*};
use crate::utils::error::Result;
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use tracing::{debug, info};
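/// Provider adapter for the Stability AI image-generation API.
///
/// Implements the OpenAI-compatible `Provider` trait for Stability's
/// text-to-image endpoints; chat completions, text completions, and
/// embeddings are rejected as unsupported.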
#[derive(Debug, Clone)]
pub struct StabilityAIProvider {
base: BaseProvider,
pricing_cache: HashMap<String, ModelPricing>,
}
impl StabilityAIProvider {
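/// Creates a Stability AI provider from the given configuration,
/// defaulting the base URL to `https://api.stability.ai` when none is set.
///
/// A minimal usage sketch (assumes a populated `ProviderConfig`):
///
/// ```ignore
/// let provider = StabilityAIProvider::new(&config).await?;
/// assert!(provider.supports_model("stable-diffusion-v1-6").await);
/// ```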
pub async fn new(config: &ProviderConfig) -> Result<Self> {
let base = BaseProvider::new(config)?;
let base_url = config
.base_url
.clone()
.unwrap_or_else(|| "https://api.stability.ai".to_string());
let provider = Self {
base: BaseProvider { base_url, ..base },
pricing_cache: Self::initialize_pricing_cache(),
};
info!(
"Stability AI provider '{}' initialized successfully",
config.name
);
Ok(provider)
}
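/// Builds the static pricing table for known Stability engines. Input cost
/// is always zero; the "output" rate is used as an approximate per-image
/// cost (see `calculate_cost`).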
fn initialize_pricing_cache() -> HashMap<String, ModelPricing> {
let mut cache = HashMap::new();
cache.insert(
"stable-diffusion-xl-1024-v1-0".to_string(),
ModelPricing {
model: "stable-diffusion-xl-1024-v1-0".to_string(),
input_cost_per_1k: 0.0,
output_cost_per_1k: 0.04,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache.insert(
"stable-diffusion-v1-6".to_string(),
ModelPricing {
model: "stable-diffusion-v1-6".to_string(),
input_cost_per_1k: 0.0,
output_cost_per_1k: 0.02,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache.insert(
"stable-diffusion-512-v2-1".to_string(),
ModelPricing {
model: "stable-diffusion-512-v2-1".to_string(),
input_cost_per_1k: 0.0,
output_cost_per_1k: 0.02,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache.insert(
"stable-diffusion-xl-beta-v2-2-2".to_string(),
ModelPricing {
model: "stable-diffusion-xl-beta-v2-2-2".to_string(),
input_cost_per_1k: 0.0,
output_cost_per_1k: 0.08,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache
}
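/// Converts Stability's `artifacts` response into an OpenAI-style
/// `ImageGenerationResponse`, keeping only artifacts that carry a
/// base64-encoded image.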
fn convert_stability_response_to_openai(
&self,
stability_response: serde_json::Value,
) -> Result<ImageGenerationResponse> {
let artifacts = stability_response
.get("artifacts")
.and_then(|a| a.as_array())
.ok_or_else(|| ProviderError::Parsing("No artifacts in response".to_string()))?;
let data: Vec<ImageObject> = artifacts
.iter()
.filter_map(|artifact| {
let base64 = artifact.get("base64")?.as_str()?;
Some(ImageObject {
url: None,
b64_json: Some(base64.to_string()),
})
})
.collect();
if data.is_empty() {
return Err(ProviderError::Parsing("No valid images in response".to_string()).into());
}
Ok(ImageGenerationResponse {
created: chrono::Utc::now().timestamp() as u64,
data,
})
}
}
#[async_trait]
impl Provider for StabilityAIProvider {
fn name(&self) -> &str {
&self.base.name
}
fn provider_type(&self) -> ProviderType {
ProviderType::Custom("stability_ai".to_string())
}
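/// Accepts explicitly configured models as well as any model id containing
/// "stable-diffusion".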
async fn supports_model(&self, model: &str) -> bool {
self.base.is_model_supported(model) || model.contains("stable-diffusion")
}
async fn supports_images(&self) -> bool {
true
}
async fn supports_embeddings(&self) -> bool {
false
}
async fn supports_streaming(&self) -> bool {
false
}
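/// Lists available engines via `GET /v1/engines/list`, falling back to a
/// static list of well-known engines if the endpoint returns an error.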
async fn list_models(&self) -> Result<Vec<Model>> {
let url = format!("{}/v1/engines/list", self.base.base_url);
let response = self
.base
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.base.api_key))
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
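// Non-success status: fall back to a static list of known engines rather
// than failing the whole call.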
if !response.status().is_success() {
let known_models = vec![
"stable-diffusion-xl-1024-v1-0",
"stable-diffusion-v1-6",
"stable-diffusion-512-v2-1",
"stable-diffusion-xl-beta-v2-2-2",
];
let models = known_models
.into_iter()
.map(|model| Model {
id: model.to_string(),
object: "model".to_string(),
created: chrono::Utc::now().timestamp() as u64,
owned_by: "stability-ai".to_string(),
})
.collect();
return Ok(models);
}
let engines_response: serde_json::Value = self.base.parse_json_response(response).await?;
let models = engines_response
.as_array()
.map(Vec::as_slice)
.unwrap_or(&[])
.iter()
.filter_map(|engine| {
Some(Model {
id: engine.get("id")?.as_str()?.to_string(),
object: "model".to_string(),
created: chrono::Utc::now().timestamp() as u64,
owned_by: "stability-ai".to_string(),
})
})
.collect();
Ok(models)
}
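/// Verifies connectivity and credentials by querying `/v1/user/account`.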
async fn health_check(&self) -> Result<()> {
debug!("Performing Stability AI health check");
let url = format!("{}/v1/user/account", self.base.base_url);
let response = self
.base
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.base.api_key))
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
if response.status().is_success() {
Ok(())
} else {
Err(
ProviderError::Unknown(format!("Health check failed: {}", response.status()))
.into(),
)
}
}
async fn chat_completion(
&self,
_request: ChatCompletionRequest,
_context: RequestContext,
) -> Result<ChatCompletionResponse> {
Err(ProviderError::InvalidRequest(
"Chat completion not supported by Stability AI".to_string(),
)
.into())
}
async fn completion(
&self,
_request: CompletionRequest,
_context: RequestContext,
) -> Result<CompletionResponse> {
Err(ProviderError::InvalidRequest(
"Text completion not supported by Stability AI".to_string(),
)
.into())
}
async fn embedding(
&self,
_request: EmbeddingRequest,
_context: RequestContext,
) -> Result<EmbeddingResponse> {
Err(
ProviderError::InvalidRequest("Embeddings not supported by Stability AI".to_string())
.into(),
)
}
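/// Generates images via Stability's `text-to-image` endpoint, translating
/// the OpenAI-style request into Stability's request format and converting
/// the response back. Defaults to `stable-diffusion-xl-1024-v1-0` when no
/// model is specified.
///
/// A minimal usage sketch (assumes `ImageGenerationRequest` implements
/// `Default` for the fields not shown):
///
/// ```ignore
/// let request = ImageGenerationRequest {
///     prompt: "a lighthouse at dusk".to_string(),
///     model: Some("stable-diffusion-xl-1024-v1-0".to_string()),
///     n: Some(1),
///     size: Some("1024x1024".to_string()),
///     ..Default::default()
/// };
/// let response = provider.image_generation(request, context).await?;
/// ```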
async fn image_generation(
&self,
request: ImageGenerationRequest,
_context: RequestContext,
) -> Result<ImageGenerationResponse> {
debug!(
"Stability AI image generation for model: {:?}",
request.model
);
let mut body = json!({
"text_prompts": [
{
"text": request.prompt,
"weight": 1.0
}
]
});
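// Translate optional OpenAI-style parameters: `n` maps to Stability's
// `samples`, and `size` ("WIDTHxHEIGHT") maps to `width`/`height`.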
if let Some(n) = request.n {
body["samples"] = json!(n);
}
if let Some(size) = &request.size {
if let Some((width, height)) = size.split_once('x') {
if let (Ok(w), Ok(h)) = (width.parse::<u32>(), height.parse::<u32>()) {
body["width"] = json!(w);
body["height"] = json!(h);
}
}
}
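// Fixed generation defaults; these are not configurable through the
// OpenAI-style request.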
body["cfg_scale"] = json!(7);
body["steps"] = json!(30);
let default_model = "stable-diffusion-xl-1024-v1-0".to_string();
let model = request.model.as_ref().unwrap_or(&default_model);
let url = format!(
"{}/v1/generation/{}/text-to-image",
self.base.base_url, model
);
let response = self
.base
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.base.api_key))
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.json(&body)
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(match status.as_u16() {
401 => ProviderError::Authentication(error_text),
429 => ProviderError::RateLimit(error_text),
404 => ProviderError::ModelNotFound(error_text),
400 => ProviderError::InvalidRequest(error_text),
_ => ProviderError::Unknown(format!("HTTP {}: {}", status, error_text)),
}
.into());
}
let stability_response: serde_json::Value = self.base.parse_json_response(response).await?;
self.convert_stability_response_to_openai(stability_response)
}
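/// Returns cached pricing for known models, or a flat default rate for
/// unknown ones.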
async fn get_model_pricing(&self, model: &str) -> Result<ModelPricing> {
if let Some(pricing) = self.pricing_cache.get(model) {
Ok(pricing.clone())
} else {
Ok(ModelPricing {
model: model.to_string(),
input_cost_per_1k: 0.0,
output_cost_per_1k: 0.03,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
})
}
}
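/// Estimates request cost. Note that `output_cost_per_1k` is applied
/// directly (treated as a per-image rate), and `output_tokens` is assumed
/// to carry the number of generated images; input tokens are ignored.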
async fn calculate_cost(
&self,
model: &str,
_input_tokens: u32,
output_tokens: u32,
) -> Result<f64> {
let pricing = self.get_model_pricing(model).await?;
let cost = (output_tokens as f64) * pricing.output_cost_per_1k;
Ok(cost)
}
}