// fal_rust/client.rs

1use reqwest::{header, Response};
2
3use super::error::FalError;
4
/// Credentials used to authenticate requests.
///
/// Values compare by content (`PartialEq`/`Eq`), so two credentials holding
/// the same strings in the same variant are equal.
///
/// Note: the derived `Debug` output contains the wrapped strings verbatim,
/// so avoid logging values of this type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientCredentials {
    /// A single, complete API key.
    Key(String),
    /// Carries one string. No constructor in this file produces this variant.
    /// NOTE(review): presumably the name of an environment variable — confirm.
    FromEnv(String),
    /// A key id followed by its secret.
    KeyPair(String, String),
}
11
/// Environment variable names that may hold a single API key, listed in the
/// order in which they are consulted.
const ENV_CANDIDATES: [&str; 2] = [
    "FAL_KEY",
    "FAL_API_KEY",
];
13
14impl ClientCredentials {
15    pub fn from_env() -> Self {
16        dotenvy::dotenv().ok();
17
18        for candidate in ENV_CANDIDATES {
19            if let Ok(key) = std::env::var(candidate) {
20                return Self::Key(key);
21            }
22        }
23
24        if let Ok(key_id) = std::env::var("FAL_KEY_ID") {
25            if let Ok(secret) = std::env::var("FAL_SECRET") {
26                return Self::KeyPair(key_id, secret);
27            }
28        }
29
30        panic!("FAL_KEY or FAL_KEY_ID and FAL_SECRET must be set in the environment");
31    }
32
33    pub fn from_key(key: &str) -> Self {
34        Self::Key(key.to_string())
35    }
36
37    pub fn from_key_pair(key_id: &str, secret: &str) -> Self {
38        Self::KeyPair(key_id.to_string(), secret.to_string())
39    }
40}
41
42#[derive(Debug, Clone)]
43pub struct FalClient {
44    credentials: ClientCredentials,
45}
46
47impl FalClient {
48    pub fn build_url(&self, path: &str) -> String {
49        let host = std::env::var("FAL_RUN_HOST").unwrap_or("fal.run".to_string());
50        let base_url = format!("https://{}", host);
51        format!("{}/{}", base_url, path.trim_start_matches('/'))
52    }
53
54    pub fn new(credentials: ClientCredentials) -> Self {
55        Self { credentials }
56    }
57
58    fn client(&self) -> reqwest::Client {
59        let mut header = header::HeaderMap::new();
60        let creds = match &self.credentials {
61            ClientCredentials::Key(key) => key.clone(),
62            ClientCredentials::KeyPair(key_id, secret) => format!("{}:{}", key_id, secret),
63            ClientCredentials::FromEnv(_) => panic!("FAL_API_KEY must be set in the environment"),
64        };
65
66        header.insert("Authorization", format!("Key {}", creds).parse().unwrap());
67
68        reqwest::Client::builder()
69            .default_headers(header)
70            .build()
71            .unwrap()
72    }
73
74    pub async fn run<T: serde::Serialize>(
75        &self,
76        funtion_id: &str,
77        inputs: T,
78    ) -> Result<Response, FalError> {
79        let client = self.client();
80        let url = self.build_url(funtion_id);
81        let res = client.post(&url).json(&inputs).send().await;
82
83        match res {
84            Ok(res) => {
85                if res.status().is_success() {
86                    Ok(res)
87                } else {
88                    Err(FalError::InvalidCredentials)
89                }
90            }
91            Err(e) => Err(FalError::RequestError(e)),
92        }
93    }
94}