banana_rust_sdk/
utils.rs

1use std::time::SystemTime;
2use reqwest::StatusCode;
3use serde_json::Value;
4use uuid::Uuid;
5use lazy_static::lazy_static;
6
7use crate::types::BananaError;
8use crate::types::BananaResponse;
9use crate::types::Payload;
10use crate::types::CheckPayload;
11
12// getting the url as a global variabel at run time
13lazy_static! {
14    static ref BANANA_URL: String = {
15        match std::env::var("BANANA_URL") {
16           Ok(v) => {
17            println!("dev mode");
18            if v == "local".to_string() {
19                return "http://localhost/".to_string();
20            } else {
21                return v;
22            }
23        },
24            Err(_) => {
25                return "https://api.banana.dev/".to_string(); 
26            } 
27        };
28    };
29}
30
31pub async fn run_main(api_key: &str, model_key: &str, model_inputs: Value) -> Result<BananaResponse, BananaError> {
32    match start_api(api_key, model_key, model_inputs).await {
33        Ok(res) => {
34            match res.finished {
35                Some(value) => {
36                    if value == true {
37                        return Ok(res)
38                    } else {
39                        match res.call_i_d {
40                            Some(value) => {
41                                loop {
42                                    println!("polling...");
43                                    match check_api(api_key, &value).await {
44                                        Ok(res) => {
45                                            if res.message.to_ascii_lowercase() == "success" {
46                                                return Ok(res)
47                                            }
48                                        }
49                                        Err(e) => return Err(e)
50                                    }
51                                }         
52        
53                            },
54                            None => return Err(BananaError::ResponseError("call id returned undefined".to_string()))
55                        }
56                    }
57                },
58                None => return Err(BananaError::ResponseError("finished returned undefined.".to_string()))
59            }
60        },
61        Err(e) => Err(e)
62    }
63}
64
65pub async fn start_main(api_key: &str, model_key: &str, model_inputs: Value) -> Result<String, BananaError> {
66    match start_api(api_key, model_key, model_inputs).await {
67        Ok(res) => {
68            match res.call_i_d {
69                Some(value) => return Ok(value),
70                None => return Err(BananaError::ResponseError("call id returned undefined.".to_string()))
71            }
72        }  
73        Err(e) => Err(e)
74    }
75}
76
77pub async fn check_main(api_key: &str, call_id: &String) -> Result<BananaResponse, BananaError> {
78    check_api(api_key, call_id).await
79}
80
81// ------------- API CALLING FUNCTIONS ------------- 
82// ------------------------------------------------- 
83
84
85async fn start_api(api_key: &str, model_key: &str, model_inputs: Value) -> Result<BananaResponse, BananaError> {
86
87    // accessing the global url variable which has to be cloned 
88    let mut url_start = BANANA_URL.clone();
89    let route_start = "start/v4/";
90    url_start.push_str(route_start);
91    
92    let created = match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
93        Ok(time) => time.as_millis() as usize,
94        Err(e) => return Err(BananaError::TimeError(e))
95    };
96
97    let payload = Payload {
98        id: Uuid::new_v4().to_string(),
99        created,
100        model_key: model_key.to_string(),
101        api_key: api_key.to_string(),
102        model_inputs,
103        start_only: false
104    };
105
106    let client = reqwest::Client::new();
107
108    match client.post(&url_start).json(&payload).send().await {
109        Ok(res) => {
110            
111            let status = res.status();
112            
113            if status != StatusCode::OK {
114                Err(BananaError::ServerError(status.to_string()))
115            } else {
116                match res.json::<BananaResponse>().await {
117                    Ok(res) => {
118                        if res.message.to_ascii_lowercase().contains("error") {
119                            return Err(BananaError::ModelError(res.message))
120                        } else {
121                            Ok(res)
122                        }
123                    },
124                    Err(e) => Err(BananaError::JsonError(e))
125
126                }    
127            }
128        },
129        Err(e) => Err(BananaError::ConnectionError(e))
130    }
131}
132
133async fn check_api(api_key: &str, call_id: &String) -> Result<BananaResponse, BananaError> {
134    
135    let mut url_check = BANANA_URL.clone();
136    println!("Hitting endpoint: {}", url_check);
137    let route_start = "check/v4/";
138    url_check.push_str(route_start);
139
140    let created = match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
141        Ok(time) => time.as_millis() as usize,
142        Err(e) => return Err(BananaError::TimeError(e))
143    };
144
145    let payload = CheckPayload {
146        id: Uuid::new_v4().to_string(),
147        created,
148        long_poll: true,
149        call_i_d: call_id.to_string(),
150        api_key: api_key.to_string()
151    };
152
153    let client = reqwest::Client::new();
154
155    match client.post(url_check).json(&payload).send().await {
156        Ok(res) => {
157            
158            let status = res.status();
159            
160            if status != StatusCode::OK {
161                Err(BananaError::ServerError(status.to_string()))
162            } else {
163                let json = res.json::<BananaResponse>().await;
164                match json {
165                    Ok(res) => {
166                        if res.message.to_ascii_lowercase().contains("error") {
167                            return Err(BananaError::ModelError(res.message))
168                        } else {
169                            Ok(res)
170                        }
171                    },
172                    Err(e) => Err(BananaError::JsonError(e))
173                }    
174            }
175        },
176        Err(e) => Err(BananaError::ConnectionError(e))
177    }
178}