edgequake-llm 0.5.1

Multi-provider LLM abstraction library with caching, rate limiting, and cost tracking
Documentation
use std::time::{Duration, Instant};

use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{json, Value};
use tokio::time::sleep;

use crate::imagegen::error::{ImageGenError, Result};
use crate::imagegen::traits::ImageGenProvider;
use crate::imagegen::types::{
    GeneratedImage, ImageGenData, ImageGenRequest, ImageGenResponse, SafetyLevel,
};

/// Model endpoint used when neither the constructor caller nor the
/// `IMAGEGEN_FAL_MODEL` environment variable selects one.
const DEFAULT_FAL_MODEL: &str = "fal-ai/flux/dev";
/// Default upper bound, in seconds, on how long a queued job is polled before
/// a timeout error is returned.
const DEFAULT_TIMEOUT_SECS: u64 = 300;
/// Default delay, in milliseconds, before the second status poll; the delay
/// doubles after every poll that finds the job still pending.
const DEFAULT_POLL_INTERVAL_MS: u64 = 500;
/// Ceiling, in milliseconds, applied to the poll delay after each doubling.
const MAX_POLL_INTERVAL_MS: u64 = 5_000;

/// Image generation provider backed by the FAL queue HTTP API.
///
/// The value is cheap to clone and carries everything needed to submit a job,
/// poll it, and fetch its result: an HTTP client, the API key, the default
/// model endpoint, and the polling/timeout settings.
#[derive(Clone)]
pub struct FalImageGen {
    /// HTTP client used for every request to the queue API.
    client: Client,
    /// Secret sent in the `Authorization: Key …` header.
    api_key: String,
    /// Model endpoint used when a request does not name one itself.
    model: String,
    /// Maximum time, in seconds, to wait for a queued job to finish.
    timeout_secs: u64,
    /// Initial delay, in milliseconds, between status polls.
    poll_interval_ms: u64,
}

// `Debug` is implemented by hand instead of derived so that the API key can
// never end up in logs or panic messages through a stray `{:?}`.
impl std::fmt::Debug for FalImageGen {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FalImageGen")
            .field("client", &self.client)
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("timeout_secs", &self.timeout_secs)
            .field("poll_interval_ms", &self.poll_interval_ms)
            .finish()
    }
}

/// Subset of the JSON body returned when a job is submitted to the queue.
#[derive(Debug, Deserialize)]
struct FalSubmitResponse {
    /// Identifier of the queued job; used to build the status and result URLs.
    request_id: String,
}

/// Subset of the JSON body returned by the job status endpoint.
#[derive(Debug, Deserialize)]
struct FalStatusResponse {
    /// Job state. The poller only distinguishes `"COMPLETED"` and `"FAILED"`;
    /// every other value is treated as "still pending".
    status: String,
    /// Optional failure message, surfaced to the caller when the job failed.
    /// Absent in the payload means `None`.
    #[serde(default)]
    error: Option<String>,
}

/// Subset of the JSON body returned by the job result endpoint.
///
/// Every field is optional in the payload; missing fields deserialize to their
/// `Default` value.
#[derive(Debug, Deserialize)]
struct FalResultResponse {
    /// Generated images; an empty list is reported to the caller as an
    /// invalid response.
    #[serde(default)]
    images: Vec<FalImage>,
    /// Seed reported for the job; copied onto every returned image.
    #[serde(default)]
    seed: Option<u64>,
    /// Prompt echoed by the provider; returned to the caller as the
    /// `enhanced_prompt` of the response.
    #[serde(default)]
    prompt: Option<String>,
    /// Safety-checker flags; if any entry is `true` the whole result is
    /// rejected as content-filtered.
    #[serde(default)]
    has_nsfw_concepts: Vec<bool>,
    /// Optional timing information used to derive the reported latency.
    #[serde(default)]
    timings: Option<FalTimings>,
}

/// One image entry inside a [`FalResultResponse`].
#[derive(Debug, Deserialize)]
struct FalImage {
    /// Location of the generated image; passed through to the caller as a URL.
    url: String,
    /// Pixel width; `0` when the payload omits it.
    #[serde(default)]
    width: u32,
    /// Pixel height; `0` when the payload omits it.
    #[serde(default)]
    height: u32,
    /// MIME type reported for the image. When absent, the caller falls back to
    /// the MIME type of the requested output format.
    #[serde(default)]
    content_type: Option<String>,
}

/// Timing section of a [`FalResultResponse`].
#[derive(Debug, Deserialize)]
struct FalTimings {
    /// Inference duration. The consumer multiplies it by 1000 to obtain
    /// milliseconds, i.e. it is interpreted as seconds.
    #[serde(default)]
    inference: Option<f64>,
}

impl FalImageGen {
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            client: Client::new(),
            api_key: api_key.into(),
            model: DEFAULT_FAL_MODEL.to_string(),
            timeout_secs: DEFAULT_TIMEOUT_SECS,
            poll_interval_ms: DEFAULT_POLL_INTERVAL_MS,
        }
    }

    pub fn from_env() -> Result<Self> {
        let api_key = std::env::var("FAL_KEY")
            .map_err(|_| ImageGenError::ConfigError("FAL_KEY must be set".to_string()))?;
        let model =
            std::env::var("IMAGEGEN_FAL_MODEL").unwrap_or_else(|_| DEFAULT_FAL_MODEL.to_string());
        let timeout_secs = std::env::var("IMAGEGEN_FAL_TIMEOUT_SECS")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_TIMEOUT_SECS);
        let poll_interval_ms = std::env::var("IMAGEGEN_FAL_POLL_INTERVAL")
            .ok()
            .and_then(|value| value.parse().ok())
            .unwrap_or(DEFAULT_POLL_INTERVAL_MS);
        Ok(Self::new(api_key)
            .with_model(model)
            .with_timeout_secs(timeout_secs)
            .with_poll_interval_ms(poll_interval_ms))
    }

    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    pub fn with_timeout_secs(mut self, timeout_secs: u64) -> Self {
        self.timeout_secs = timeout_secs;
        self
    }

    pub fn with_poll_interval_ms(mut self, poll_interval_ms: u64) -> Self {
        self.poll_interval_ms = poll_interval_ms;
        self
    }

    fn active_model<'a>(&'a self, request: &'a ImageGenRequest) -> &'a str {
        request.model.as_deref().unwrap_or(&self.model)
    }

    fn submit_url(&self, endpoint: &str) -> String {
        format!("https://queue.fal.run/{endpoint}")
    }

    fn status_url(&self, endpoint: &str, request_id: &str) -> String {
        format!("https://queue.fal.run/{endpoint}/requests/{request_id}/status")
    }

    fn result_url(&self, endpoint: &str, request_id: &str) -> String {
        format!("https://queue.fal.run/{endpoint}/requests/{request_id}")
    }

    fn build_request_body(&self, request: &ImageGenRequest, endpoint: &str) -> Value {
        let options = &request.options;
        let image_size = match (options.width, options.height) {
            (Some(width), Some(height)) => json!({ "width": width, "height": height }),
            _ => json!(options.aspect_ratio_or_default().as_fal_str()),
        };

        let default_steps = if endpoint.contains("schnell") {
            4
        } else if endpoint.contains("pro") {
            40
        } else {
            28
        };

        let steps = options
            .extra
            .get("steps")
            .and_then(|value| value.as_u64())
            .unwrap_or(default_steps);
        let acceleration = options
            .extra
            .get("acceleration")
            .cloned()
            .unwrap_or_else(|| json!("none"));

        json!({
            "prompt": request.prompt,
            "image_size": image_size,
            "num_inference_steps": steps,
            "guidance_scale": options.guidance_scale.unwrap_or(3.5),
            "num_images": options.count_or_default(),
            "enable_safety_checker": !matches!(options.safety_level_or_default(), SafetyLevel::BlockNone),
            "output_format": options.output_format_or_default().as_fal_str(),
            "acceleration": acceleration,
            "seed": options.seed,
        })
    }

    async fn submit(&self, endpoint: &str, body: &Value) -> Result<String> {
        let response = self
            .client
            .post(self.submit_url(endpoint))
            .header("Authorization", format!("Key {}", self.api_key))
            .json(body)
            .send()
            .await?;
        let status = response.status();
        let body = response.text().await?;

        if !status.is_success() {
            return Err(match status.as_u16() {
                400 | 404 => ImageGenError::InvalidRequest(body),
                401 | 403 => ImageGenError::AuthError(body),
                429 => ImageGenError::RateLimited { retry_after: None },
                _ => ImageGenError::ProviderError(body),
            });
        }

        let payload: FalSubmitResponse = serde_json::from_str(&body)?;
        Ok(payload.request_id)
    }

    async fn poll_until_complete(&self, endpoint: &str, request_id: &str) -> Result<()> {
        let started_at = Instant::now();
        let mut interval_ms = self.poll_interval_ms;

        loop {
            if started_at.elapsed() > Duration::from_secs(self.timeout_secs) {
                return Err(ImageGenError::Timeout {
                    elapsed_secs: self.timeout_secs,
                });
            }

            let response = self
                .client
                .get(self.status_url(endpoint, request_id))
                .header("Authorization", format!("Key {}", self.api_key))
                .send()
                .await?;
            let status = response.status();
            let body = response.text().await?;
            if !status.is_success() {
                return Err(ImageGenError::ProviderError(body));
            }
            let payload: FalStatusResponse = serde_json::from_str(&body)?;
            match payload.status.as_str() {
                "COMPLETED" => return Ok(()),
                "FAILED" => {
                    return Err(ImageGenError::ProviderError(
                        payload
                            .error
                            .unwrap_or_else(|| "FAL queue job failed".to_string()),
                    ))
                }
                _ => {
                    sleep(Duration::from_millis(interval_ms)).await;
                    interval_ms = (interval_ms * 2).min(MAX_POLL_INTERVAL_MS);
                }
            }
        }
    }

    async fn fetch_result(&self, endpoint: &str, request_id: &str) -> Result<FalResultResponse> {
        let response = self
            .client
            .get(self.result_url(endpoint, request_id))
            .header("Authorization", format!("Key {}", self.api_key))
            .send()
            .await?;
        let status = response.status();
        let body = response.text().await?;
        if !status.is_success() {
            return Err(ImageGenError::ProviderError(body));
        }
        Ok(serde_json::from_str(&body)?)
    }
}

#[async_trait]
impl ImageGenProvider for FalImageGen {
    /// Stable identifier of this provider.
    fn name(&self) -> &str {
        "fal-ai"
    }

    /// Model endpoint used when a request does not pick one.
    fn default_model(&self) -> &str {
        self.model.as_str()
    }

    /// Model endpoints advertised by this provider.
    fn available_models(&self) -> Vec<&str> {
        const MODELS: [&str; 6] = [
            "fal-ai/flux/dev",
            "fal-ai/flux/schnell",
            "fal-ai/flux/pro",
            "fal-ai/stable-diffusion-xl",
            "fal-ai/aura-flow",
            "fal-ai/hyper-sdxl",
        ];
        MODELS.to_vec()
    }

    /// Runs one generation job end to end: submit, poll, fetch, convert.
    ///
    /// # Errors
    ///
    /// * `InvalidRequest` when the prompt is empty or whitespace only.
    /// * `ContentFiltered` when any safety-checker flag in the result is set.
    /// * `InvalidResponse` when the result carries no images.
    /// * Any error raised while submitting, polling, or fetching the job.
    async fn generate(&self, request: &ImageGenRequest) -> Result<ImageGenResponse> {
        let prompt_is_blank = request.prompt.trim().is_empty();
        if prompt_is_blank {
            return Err(ImageGenError::InvalidRequest(
                "prompt must not be empty".to_string(),
            ));
        }

        // Queue round trip: enqueue the job, wait for it, then pull the output.
        let model = self.active_model(request).to_string();
        let body = self.build_request_body(request, &model);
        let job_id = self.submit(&model, &body).await?;
        self.poll_until_complete(&model, &job_id).await?;
        let result = self.fetch_result(&model, &job_id).await?;

        if result.has_nsfw_concepts.contains(&true) {
            return Err(ImageGenError::ContentFiltered {
                reason: "FAL safety checker flagged the generated output".to_string(),
            });
        }

        // The provider reports inference time in seconds; callers get whole
        // milliseconds, or zero when no timing was reported.
        let latency_ms = match result.timings.and_then(|timings| timings.inference) {
            Some(seconds) => (seconds * 1000.0) as u64,
            None => 0,
        };

        let seed = result.seed;
        let mut images = Vec::with_capacity(result.images.len());
        for image in result.images {
            let mime_type = match image.content_type {
                Some(reported) => reported,
                // Fall back to the MIME type of the format that was requested.
                None => request
                    .options
                    .output_format_or_default()
                    .mime_type()
                    .to_string(),
            };
            images.push(GeneratedImage {
                data: ImageGenData::Url(image.url),
                width: image.width,
                height: image.height,
                mime_type,
                seed,
            });
        }

        if images.is_empty() {
            return Err(ImageGenError::InvalidResponse(
                "FAL result did not include any images".to_string(),
            ));
        }

        Ok(ImageGenResponse {
            images,
            provider: self.name().to_string(),
            model,
            latency_ms,
            enhanced_prompt: result.prompt,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::FalImageGen;
    use crate::imagegen::types::{AspectRatio, ImageGenOptions, ImageGenRequest};

    /// With no explicit dimensions the body carries the aspect-ratio preset
    /// name, and a requested seed is forwarded unchanged.
    #[test]
    fn test_build_request_body_uses_ratio_and_seed() {
        let options = ImageGenOptions {
            aspect_ratio: Some(AspectRatio::Landscape169),
            seed: Some(42),
            ..Default::default()
        };
        let request = ImageGenRequest::new("test").with_options(options);

        let payload = FalImageGen::new("key").build_request_body(&request, "fal-ai/flux/dev");

        assert_eq!(payload["image_size"].as_str(), Some("landscape_16_9"));
        assert_eq!(payload["seed"].as_u64(), Some(42));
    }
}