1use reqwest::{header, Response};
2
3use super::error::FalError;
4
/// Credentials used to authenticate against the fal API.
///
/// `Debug` is implemented by hand so that key material never ends up in logs
/// or panic messages: only the variant name is printed, every payload is
/// replaced by `<redacted>`.
#[derive(Clone)]
pub enum ClientCredentials {
    /// A single API key.
    Key(String),
    /// NOTE(review): presumably the name of an environment variable to read
    /// the key from. Nothing in this file constructs this variant — confirm
    /// the intended semantics.
    FromEnv(String),
    /// A key id and its secret, in that order.
    KeyPair(String, String),
}

impl std::fmt::Debug for ClientCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            Self::Key(_) => "Key",
            Self::FromEnv(_) => "FromEnv",
            Self::KeyPair(..) => "KeyPair",
        };
        // `format_args!` keeps the placeholder unquoted: `Key(<redacted>)`.
        f.debug_tuple(variant)
            .field(&format_args!("<redacted>"))
            .finish()
    }
}
11
/// Environment variable names that may hold a single API key, in the order
/// they are tried by `ClientCredentials::from_env`.
const ENV_CANDIDATES: [&str; 2] = ["FAL_KEY", "FAL_API_KEY"];
13
14impl ClientCredentials {
15 pub fn from_env() -> Self {
16 dotenvy::dotenv().ok();
17
18 for candidate in ENV_CANDIDATES {
19 if let Ok(key) = std::env::var(candidate) {
20 return Self::Key(key);
21 }
22 }
23
24 if let Ok(key_id) = std::env::var("FAL_KEY_ID") {
25 if let Ok(secret) = std::env::var("FAL_SECRET") {
26 return Self::KeyPair(key_id, secret);
27 }
28 }
29
30 panic!("FAL_KEY or FAL_KEY_ID and FAL_SECRET must be set in the environment");
31 }
32
33 pub fn from_key(key: &str) -> Self {
34 Self::Key(key.to_string())
35 }
36
37 pub fn from_key_pair(key_id: &str, secret: &str) -> Self {
38 Self::KeyPair(key_id.to_string(), secret.to_string())
39 }
40}
41
42#[derive(Debug, Clone)]
43pub struct FalClient {
44 credentials: ClientCredentials,
45}
46
47impl FalClient {
48 pub fn build_url(&self, path: &str) -> String {
49 let host = std::env::var("FAL_RUN_HOST").unwrap_or("fal.run".to_string());
50 let base_url = format!("https://{}", host);
51 format!("{}/{}", base_url, path.trim_start_matches('/'))
52 }
53
54 pub fn new(credentials: ClientCredentials) -> Self {
55 Self { credentials }
56 }
57
58 fn client(&self) -> reqwest::Client {
59 let mut header = header::HeaderMap::new();
60 let creds = match &self.credentials {
61 ClientCredentials::Key(key) => key.clone(),
62 ClientCredentials::KeyPair(key_id, secret) => format!("{}:{}", key_id, secret),
63 ClientCredentials::FromEnv(_) => panic!("FAL_API_KEY must be set in the environment"),
64 };
65
66 header.insert("Authorization", format!("Key {}", creds).parse().unwrap());
67
68 reqwest::Client::builder()
69 .default_headers(header)
70 .build()
71 .unwrap()
72 }
73
74 pub async fn run<T: serde::Serialize>(
75 &self,
76 funtion_id: &str,
77 inputs: T,
78 ) -> Result<Response, FalError> {
79 let client = self.client();
80 let url = self.build_url(funtion_id);
81 let res = client.post(&url).json(&inputs).send().await;
82
83 match res {
84 Ok(res) => {
85 if res.status().is_success() {
86 Ok(res)
87 } else {
88 Err(FalError::InvalidCredentials)
89 }
90 }
91 Err(e) => Err(FalError::RequestError(e)),
92 }
93 }
94}