//! kai_rs — async Rust client for the KoboldAI HTTP API (lib.rs).

1mod generation_response;
2mod generation_settings;
3mod model;
4pub mod prelude;
5mod string_response;
6
use anyhow::{bail, Context, Result};
use generation_response::{GenerationErrorResult, GenerationOkResult};
use generation_settings::GenerationSettings;
use model::Model;
use serde_json::{json, Value};
use std::fmt;
13
/// Version of the KoboldAI HTTP API to target.
///
/// Rendered as the lowercase URL path segment (e.g. `v1`) via its
/// [`fmt::Display`] impl, so it can be spliced directly into request URLs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum APIVersion {
    V1,
}
17
18impl fmt::Display for APIVersion {
19    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
20        match self {
21            APIVersion::V1 => write!(f, "v1"),
22        }
23    }
24}
25
/// Async HTTP client for a KoboldAI server.
pub struct KoboldClient {
    /// Base URL including the API prefix, e.g. `http://localhost:5000/api/v1`.
    api_url: String,
    /// Shared reqwest client reused across requests.
    client: reqwest::Client,
}
30
31impl KoboldClient {
32    pub fn new(api_url: &str, api_version: APIVersion) -> Self {
33        let api_url = format!("{api_url}/api/{api_version}");
34        let client = reqwest::Client::new();
35
36        KoboldClient { api_url, client }
37    }
38
39    fn get_result_string(value: serde_json::Value) -> Result<String> {
40        let string = value
41            .get("result")
42            .expect("Property 'result' not found.")
43            .as_str()
44            .expect("Could not convert 'result' into a string.");
45
46        Ok(String::from(string))
47    }
48
49    pub async fn get_version(&self) -> Result<String> {
50        let response = self
51            .client
52            .get(format!("{}/info/version", self.api_url))
53            .send()
54            .await?;
55
56        let version = Self::get_result_string(response.json().await?)?;
57
58        Ok(version)
59    }
60
61    pub async fn get_model(&self) -> Result<Option<Model>> {
62        let response = self
63            .client
64            .get(format!("{}/model", self.api_url))
65            .send()
66            .await?;
67
68        let model_string = Self::get_result_string(response.json().await?)?;
69
70        if model_string.eq("ReadOnly") {
71            Ok(None)
72        } else {
73            Ok(Some(Model::from(model_string)))
74        }
75    }
76
77    pub async fn load_model(&self, model: Model, gpu_layers: Vec<i32>) -> Result<()> {
78        let gpu_layers_breakmodel: String = gpu_layers
79            .iter()
80            .map(|i| i.to_string())
81            .collect::<Vec<String>>()
82            .join(",");
83
84        let response = self
85            .client
86            .put(format!("{}/model", self.api_url))
87            .json(&json!({ "model": model, "gpu_layers": gpu_layers_breakmodel }))
88            .send()
89            .await?;
90
91        if response.status().is_success() {
92            Ok(())
93        } else {
94            bail!(response.status())
95        }
96    }
97
98    pub async fn generate(
99        &self,
100        prompt: &str,
101        settings: GenerationSettings,
102    ) -> Result<Vec<String>> {
103        let settings_value = inject_prompt(prompt, settings)?;
104
105        let response = self
106            .client
107            .post(format!("{}/generate", self.api_url))
108            .json(&settings_value)
109            .send()
110            .await?;
111
112        if response.status().is_success() {
113            let generation_response = response.json::<GenerationOkResult>().await?;
114
115            let generations = generation_response
116                .results
117                .into_iter()
118                .map(|generation| generation.text)
119                .collect();
120
121            Ok(generations)
122        } else {
123            println!("HTTP {}", response.status());
124
125            let error_response = &response.json::<GenerationErrorResult>().await?;
126
127            bail!(format!(
128                "[{:?}] {}",
129                &error_response.detail.error_type, &error_response.detail.message
130            ));
131        }
132    }
133}
134
135fn inject_prompt(prompt: &str, settings: GenerationSettings) -> Result<Value> {
136    let mut settings_value = serde_json::to_value(&settings)?;
137
138    // inject prompt into settings object before serializing
139    settings_value = match settings_value {
140        Value::Object(m) => {
141            let mut m = m.clone();
142            m.insert(
143                "prompt".into(),
144                serde_json::Value::String(prompt.to_string()),
145            );
146
147            Value::Object(m)
148        }
149        _ => bail!("GenerationSettings object was not an object!"),
150    };
151
152    Ok(settings_value)
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    // ! Tests prefixed with "debug" are not unit tests. You need KoboldAI running to use them.

    #[tokio::test]
    async fn debug_get_version() {
        let kai = KoboldClient::new("http://localhost:5000", APIVersion::V1);
        kai.get_version().await.unwrap();
    }

    #[tokio::test]
    async fn debug_get_model() {
        let kai = KoboldClient::new("http://localhost:5000", APIVersion::V1);
        let model = kai.get_model().await.unwrap();
        dbg!(model);
    }

    #[tokio::test]
    async fn debug_load_model() {
        let kai = KoboldClient::new("http://localhost:5000", APIVersion::V1);
        // ! System dependent
        let gpu_layers = vec![28];
        kai.load_model(Model::from("./pygmalion-6b_dev"), gpu_layers)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn debug_generate() {
        let kai = KoboldClient::new("http://localhost:5000", APIVersion::V1);
        let settings = GenerationSettings::default();

        let prompt = "You: Hi. How are you?";
        // `unwrap` fails the test with the error attached; no need for the
        // `assert!(true)` / `assert!(false)` pattern.
        let generations = kai.generate(prompt, settings).await.unwrap();
        println!("{}{}", prompt, generations[0]);
    }

    #[test]
    fn inject_prompt_into_generation_settings_object() {
        let settings = GenerationSettings::default();
        // `inject_prompt` takes `&str` directly; no conversion needed.
        let result = inject_prompt("test", settings).unwrap();

        assert_eq!(
            result
                .as_object()
                .unwrap()
                .get("prompt")
                .unwrap()
                .as_str()
                .unwrap(),
            "test"
        )
    }

    #[test]
    fn decode_successful_generation() {
        let value = json!({
            "results": [{
                "text": "testing"
            }]
        });

        let ok_result = serde_json::from_value::<GenerationOkResult>(value).unwrap();

        assert_eq!(ok_result.results[0].text, "testing")
    }

    #[test]
    fn decode_failed_generation() {
        let value = json!({
            "detail": {
                "msg": "test",
                "type": "not_implemented"
            }
        });

        let err_result = serde_json::from_value::<GenerationErrorResult>(value).unwrap();

        assert_eq!(err_result.detail.message, "test")
    }
}