use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Image size as passed to fal.ai: either a named preset or explicit pixel
/// dimensions.
///
/// The enum is `#[serde(untagged)]`, so no variant tag is written: a preset
/// serializes as a bare string and a custom size as an object with `width`
/// and `height` fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ImageSize {
    /// A preset referred to by name, e.g. `"square_hd"`.
    Preset(String),
    /// Explicit width and height.
    Custom { width: u32, height: u32 },
}

impl ImageSize {
    /// Converts an OpenAI-style `"<width>x<height>"` size string.
    ///
    /// Resolution order:
    /// 1. Sizes with a dedicated mapping become the matching [`ImageSize::Preset`].
    /// 2. Any other string that splits at the first `'x'` into two parts which
    ///    both parse as `u32` becomes [`ImageSize::Custom`].
    /// 3. Everything else falls back to the `"landscape_4_3"` preset.
    pub fn from_openai_size(size: &str) -> Self {
        let known_preset = match size {
            "1024x1024" => Some("square_hd"),
            "512x512" => Some("square"),
            "1792x1024" => Some("landscape_16_9"),
            "1024x1792" => Some("portrait_16_9"),
            "1024x768" => Some("landscape_4_3"),
            "768x1024" => Some("portrait_4_3"),
            _ => None,
        };
        if let Some(preset_name) = known_preset {
            return ImageSize::Preset(preset_name.to_string());
        }

        // Not a known size: try to read it as explicit dimensions.
        let dimensions = size
            .split_once('x')
            .map(|(w, h)| (w.parse::<u32>(), h.parse::<u32>()));
        match dimensions {
            Some((Ok(width), Ok(height))) => ImageSize::Custom { width, height },
            // No separator, or one of the two parts is not a valid `u32`.
            _ => ImageSize::Preset("landscape_4_3".to_string()),
        }
    }
}
/// Static description of a single fal.ai image model.
#[derive(Debug, Clone)]
pub struct FalAIModel {
    /// Model identifier, e.g. `"fal-ai/flux/schnell"`.
    pub id: String,
    /// Human-readable model name.
    pub name: String,
    /// Short free-text description.
    pub description: String,
    /// Cost charged per generated image.
    pub cost_per_image: f64,
    /// Names of the size presets the model accepts.
    pub supported_sizes: Vec<String>,
    /// Upper bound on images per request.
    pub max_images: u32,
    /// Whether the model supports prompt enhancement.
    pub supports_prompt_enhancement: bool,
}

impl FalAIModel {
    /// Builds a model description with the default capabilities: all six
    /// size presets, at most 4 images, and no prompt enhancement.
    ///
    /// Use the `with_*` methods to override the defaults.
    pub fn new(id: &str, name: &str, description: &str, cost_per_image: f64) -> Self {
        const DEFAULT_SIZES: [&str; 6] = [
            "square",
            "square_hd",
            "portrait_4_3",
            "portrait_16_9",
            "landscape_4_3",
            "landscape_16_9",
        ];

        Self {
            id: String::from(id),
            name: String::from(name),
            description: String::from(description),
            cost_per_image,
            supported_sizes: DEFAULT_SIZES.iter().map(|size| size.to_string()).collect(),
            max_images: 4,
            supports_prompt_enhancement: false,
        }
    }

    /// Returns the model with `max_images` replaced by `max`.
    pub fn with_max_images(self, max: u32) -> Self {
        Self {
            max_images: max,
            ..self
        }
    }

    /// Returns the model with its supported sizes replaced by `sizes`.
    pub fn with_sizes(self, sizes: Vec<String>) -> Self {
        Self {
            supported_sizes: sizes,
            ..self
        }
    }

    /// Returns the model with prompt enhancement marked as supported.
    pub fn with_prompt_enhancement(self) -> Self {
        Self {
            supports_prompt_enhancement: true,
            ..self
        }
    }
}
/// Lookup table of fal.ai image models, keyed by model id.
#[derive(Debug, Clone)]
pub struct FalAIModelRegistry {
    /// Registered models, indexed by their `id`.
    models: HashMap<String, FalAIModel>,
}

impl Default for FalAIModelRegistry {
    /// Equivalent to [`FalAIModelRegistry::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl FalAIModelRegistry {
    /// Creates a registry pre-populated with the built-in models.
    pub fn new() -> Self {
        let builtin = [
            FalAIModel::new(
                "fal-ai/flux/schnell",
                "Flux Schnell",
                "Fast high-quality image generation",
                0.003,
            ),
            FalAIModel::new(
                "fal-ai/flux-pro/v1.1",
                "Flux Pro v1.1",
                "Professional quality image generation",
                0.05,
            )
            .with_prompt_enhancement(),
            FalAIModel::new(
                "fal-ai/flux-pro/v1.1-ultra",
                "Flux Pro v1.1 Ultra",
                "Ultra high-quality image generation",
                0.06,
            )
            .with_prompt_enhancement(),
            FalAIModel::new(
                "fal-ai/stable-diffusion-v3-medium",
                "Stable Diffusion 3 Medium",
                "Stable Diffusion 3 medium quality",
                0.035,
            ),
            FalAIModel::new(
                "fal-ai/recraft/v3/text-to-image",
                "Recraft V3",
                "High-quality artistic image generation",
                0.04,
            ),
            FalAIModel::new(
                "fal-ai/imagen4/preview",
                "Imagen 4 Preview",
                "Google Imagen 4 preview model",
                0.04,
            ),
            FalAIModel::new(
                "fal-ai/ideogram/v3",
                "Ideogram V3",
                "Ideogram text-to-image model",
                0.08,
            ),
            FalAIModel::new(
                "fal-ai/bria/text-to-image/hd",
                "BRIA HD",
                "BRIA high-definition image generation",
                0.02,
            ),
        ];

        let mut registry = Self {
            models: HashMap::with_capacity(builtin.len()),
        };
        // Populate through `register` so every key is derived from the
        // model's own id; the id is written once and cannot drift from the key.
        for model in builtin {
            registry.register(model);
        }
        registry
    }

    /// Returns the model registered under `model_id`, if any.
    pub fn get(&self, model_id: &str) -> Option<&FalAIModel> {
        self.models.get(model_id)
    }

    /// Returns `true` when a model is registered under `model_id`.
    pub fn is_supported(&self, model_id: &str) -> bool {
        self.models.contains_key(model_id)
    }

    /// Returns every registered model, sorted by id.
    ///
    /// `HashMap` iteration order is arbitrary and can differ between runs, so
    /// the result is sorted to give callers a stable, reproducible listing.
    pub fn list_models(&self) -> Vec<&FalAIModel> {
        let mut models: Vec<&FalAIModel> = self.models.values().collect();
        models.sort_unstable_by(|a, b| a.id.cmp(&b.id));
        models
    }

    /// Returns the per-image cost of `model_id`, or `0.0` when the model is
    /// not registered.
    pub fn get_cost_per_image(&self, model_id: &str) -> f64 {
        match self.models.get(model_id) {
            Some(model) => model.cost_per_image,
            None => 0.0,
        }
    }

    /// Adds `model` under its own id, replacing any model already registered
    /// with that id.
    pub fn register(&mut self, model: FalAIModel) {
        self.models.insert(model.id.clone(), model);
    }
}
/// OpenAI-style parameter names that [`map_openai_to_fal_params`] translates
/// explicitly; that function copies every other key through unchanged.
pub const SUPPORTED_OPENAI_PARAMS: &[&str] = &["n", "response_format", "size"];
/// Translates an OpenAI-style parameter object into fal.ai parameter names.
///
/// Per-key rules:
/// - `n` is copied to `num_images`.
/// - `response_format` becomes `output_format`: `"b64_json"`, `"url"` and
///   non-string values map to `"jpeg"`; any other string is passed through.
/// - `size` (strings only) is converted with [`ImageSize::from_openai_size`]
///   and stored as `image_size`; a non-string `size` is dropped.
/// - Every other key is copied unchanged.
///
/// Returns a JSON object; when `params` is not an object the result is empty.
pub fn map_openai_to_fal_params(params: &serde_json::Value) -> serde_json::Value {
    let mut translated = serde_json::Map::new();

    let Some(source) = params.as_object() else {
        return serde_json::Value::Object(translated);
    };

    for (name, raw) in source {
        // Each key yields at most one fal.ai entry; `None` means "drop it".
        let entry = match name.as_str() {
            "n" => Some(("num_images".to_string(), raw.clone())),
            "response_format" => {
                let output_format = match raw.as_str() {
                    None | Some("b64_json" | "url") => "jpeg",
                    Some(other) => other,
                };
                Some(("output_format".to_string(), serde_json::json!(output_format)))
            }
            "size" => raw.as_str().map(|spec| {
                let image_size = ImageSize::from_openai_size(spec);
                (
                    "image_size".to_string(),
                    // Falls back to JSON null if serialization fails.
                    serde_json::to_value(image_size).unwrap_or_default(),
                )
            }),
            _ => Some((name.clone(), raw.clone())),
        };

        if let Some((fal_key, fal_value)) = entry {
            translated.insert(fal_key, fal_value);
        }
    }

    serde_json::Value::Object(translated)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A size with a dedicated mapping resolves to its preset name.
    #[test]
    fn test_image_size_from_openai_standard() {
        let ImageSize::Preset(preset) = ImageSize::from_openai_size("1024x1024") else {
            panic!("Expected preset size");
        };
        assert_eq!(preset, "square_hd");
    }

    /// An unmapped but well-formed size is parsed into explicit dimensions.
    #[test]
    fn test_image_size_from_openai_custom() {
        let ImageSize::Custom { width, height } = ImageSize::from_openai_size("800x600") else {
            panic!("Expected custom size");
        };
        assert_eq!(width, 800);
        assert_eq!(height, 600);
    }

    /// Unparseable input falls back to the default preset.
    #[test]
    fn test_image_size_from_openai_invalid() {
        let ImageSize::Preset(preset) = ImageSize::from_openai_size("invalid") else {
            panic!("Expected fallback preset");
        };
        assert_eq!(preset, "landscape_4_3");
    }

    /// A fresh registry already knows the built-in models.
    #[test]
    fn test_model_registry_default() {
        let registry = FalAIModelRegistry::new();
        for model_id in ["fal-ai/flux/schnell", "fal-ai/flux-pro/v1.1"] {
            assert!(registry.is_supported(model_id));
        }
    }

    /// Looking up a built-in model returns its metadata.
    #[test]
    fn test_model_registry_get() {
        let registry = FalAIModelRegistry::new();
        let found_name = registry
            .get("fal-ai/flux/schnell")
            .map(|model| model.name.as_str());
        assert_eq!(found_name, Some("Flux Schnell"));
    }

    /// Built-in models carry a positive per-image cost.
    #[test]
    fn test_model_registry_cost() {
        let registry = FalAIModelRegistry::new();
        assert!(registry.get_cost_per_image("fal-ai/flux/schnell") > 0.0);
    }

    /// Unknown ids are unsupported and cost nothing.
    #[test]
    fn test_model_registry_unknown() {
        let registry = FalAIModelRegistry::new();
        let unknown_id = "unknown-model";
        assert!(!registry.is_supported(unknown_id));
        assert_eq!(registry.get_cost_per_image(unknown_id), 0.0);
    }

    /// The three translated OpenAI keys end up under their fal.ai names.
    #[test]
    fn test_map_openai_to_fal_params() {
        let mapped = map_openai_to_fal_params(&serde_json::json!({
            "n": 2,
            "size": "1024x1024",
            "response_format": "url"
        }));
        assert_eq!(mapped["num_images"], 2);
        assert_eq!(mapped["output_format"], "jpeg");
        assert!(mapped.get("image_size").is_some());
    }

    /// Builder methods override the defaults set by `new`.
    #[test]
    fn test_model_with_builder() {
        let built = FalAIModel::new("test", "Test Model", "A test model", 0.01)
            .with_max_images(8)
            .with_prompt_enhancement();
        assert_eq!(built.max_images, 8);
        assert!(built.supports_prompt_enhancement);
    }

    /// A model added with `register` becomes supported.
    #[test]
    fn test_register_custom_model() {
        let mut registry = FalAIModelRegistry::new();
        registry.register(FalAIModel::new(
            "custom/model",
            "Custom",
            "Custom model",
            0.1,
        ));
        assert!(registry.is_supported("custom/model"));
    }

    /// The default registry lists at least one model.
    #[test]
    fn test_list_models() {
        let listed = FalAIModelRegistry::new().list_models();
        assert!(!listed.is_empty());
    }
}