1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
//! The Banana SDK contains three simple async functions to call the [Banana](https://www.banana.dev/) services API.
//! 
//! We're moving fast and so we will most likely not prioritize backwards compatibility.
//! The `run()` function is what you'll use 99% of the time; the others can be seen as helper functions.
//! 
//! # Examples
//!
//! Basic usage:
//!
//! ```
//! use banana_rust_sdk;
//! use serde::Serialize;
//!
//!#[tokio::main]
//!async fn main() {
//!    #[derive(Serialize)]
//!    struct ModelInputs {
//!        prompt: String
//!    }
//!    
//!    let api_key = "API_KEY";
//!    let model_key = "MODEL_KEY";
//!    let model_inputs = ModelInputs {
//!        prompt: "try to predict the next [MASK] of this sentence.".to_string()
//!    };
//!
//!    let model_inputs = serde_json::to_value(model_inputs).unwrap();
//!
//!    let res = banana_rust_sdk::run(api_key, model_key, model_inputs).await.unwrap();
//!    let json = serde_json::to_value(res).unwrap();
//!    println!("{:?}", json);
//!}
//! ```

use crate::types::BananaError;
use crate::types::BananaResponse;
use crate::utils::run_main;
use crate::utils::check_main;
use crate::utils::start_main;
use serde_json::Value;

pub mod utils;
pub mod types;

/// Calls your model on Banana; this is the SDK's primary entry point.
///
/// The arguments are forwarded unchanged to `run_main`, and its result is
/// returned as-is.
///
/// * `api_key` - API key passed through to `run_main`.
/// * `model_key` - model key passed through to `run_main`.
/// * `model_inputs` - JSON payload for the model. Its shape is defined by the
///   caller, so it is not validated here.
///
/// # Errors
///
/// Returns the `BananaError` reported by `run_main`.
///
/// # Example
///
/// The key strings below are placeholders.
///
/// ```
/// use banana_rust_sdk;
/// use serde::Serialize;
///
/// #[tokio::main]
/// async fn main() {
///     #[derive(Serialize)]
///     struct ModelInputs {
///         prompt: String
///     }
///
///     let api_key = "API_KEY";
///     let model_key = "MODEL_KEY";
///     let model_inputs = ModelInputs {
///         prompt: "try to predict the next [MASK] of this sentence.".to_string()
///     };
///
///     let model_inputs = serde_json::to_value(model_inputs).unwrap();
///
///     let res = banana_rust_sdk::run(api_key, model_key, model_inputs).await.unwrap();
///     let json = serde_json::to_value(res).unwrap();
///     println!("{:?}", json);
/// }
/// ```
pub async fn run(
    api_key: &str,
    model_key: &str,
    model_inputs: Value,
) -> Result<BananaResponse, BananaError> {
    let response = run_main(api_key, model_key, model_inputs).await?;
    Ok(response)
}

/// Calls the API without checking the queue; in general, prefer `run` over this.
///
/// The arguments are forwarded unchanged to `start_main`, and its result is
/// returned as-is.
///
/// NOTE(review): the returned `String` is presumably the call id that `check`
/// accepts — confirm against `utils::start_main`.
///
/// # Errors
///
/// Returns the `BananaError` reported by `start_main`.
pub async fn start(
    api_key: &str,
    model_key: &str,
    model_inputs: Value,
) -> Result<String, BananaError> {
    let started = start_main(api_key, model_key, model_inputs).await?;
    Ok(started)
}

/// Helper function to check on a call, identified by `call_id`, in the queue.
///
/// The arguments are forwarded unchanged to `check_main`, and its result is
/// returned as-is.
///
/// NOTE(review): presumably this reports the status or result of a call that
/// was started earlier — confirm against `utils::check_main`.
///
/// # Errors
///
/// Returns the `BananaError` reported by `check_main`.
pub async fn check(api_key: &str, call_id: &String) -> Result<BananaResponse, BananaError> {
    let response = check_main(api_key, call_id).await?;
    Ok(response)
}


#[cfg(test)]
mod tests {
    use super::*;
    use serde::Serialize;

    // Placeholder credentials shared by every test in this module.
    const API_KEY: &str = "API_KEY";
    const MODEL_KEY: &str = "MODEL_KEY";

    /// Example input payload used by all tests.
    #[derive(Serialize)]
    struct ModelInputs {
        prompt: String,
    }

    /// Builds the JSON payload that every test sends.
    ///
    /// Note: since the model inputs are whatever the user defines them to be,
    /// we can't check the types of model inputs in the Banana API.
    /// Writing a wrapper function around the run function that does this is
    /// recommended.
    fn sample_inputs() -> serde_json::Value {
        let model_inputs = ModelInputs {
            prompt: "Paris is the capital of [MASK]".to_string(),
        };

        serde_json::to_value(model_inputs).unwrap()
    }

    /// `run` must succeed for the sample payload.
    #[tokio::test]
    async fn test_run() {
        run(API_KEY, MODEL_KEY, sample_inputs()).await.unwrap();
    }

    /// `start` must succeed for the sample payload.
    #[tokio::test]
    async fn test_start() {
        start(API_KEY, MODEL_KEY, sample_inputs()).await.unwrap();
    }

    /// Runs the model, then feeds the call id from the response into `check`.
    #[tokio::test]
    async fn test_check() {
        let res = match run(API_KEY, MODEL_KEY, sample_inputs()).await {
            Ok(res) => res,
            Err(_) => panic!("not able to call run() in test_check"),
        };

        // The outcome of `check` is only printed, not asserted: either variant
        // lets the test pass, exactly as before.
        match check(API_KEY, &res.call_i_d.unwrap()).await {
            Ok(out) => println!("{:#?}", out),
            Err(e) => println!("{:#?}", e),
        };
    }
}