1mod generation_response;
2mod generation_settings;
3mod model;
4pub mod prelude;
5mod string_response;
6
use anyhow::{anyhow, bail, Result};
use generation_response::{GenerationErrorResult, GenerationOkResult};
use generation_settings::GenerationSettings;
use model::Model;
use serde_json::{json, Value};
use std::fmt;
13
/// KoboldAI API versions this client can target; rendered into the
/// request path (see the `Display` impl, which yields e.g. "v1").
///
/// Derives added: public types should be `Debug` (and this enum is a
/// trivially copyable tag), which the original omitted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum APIVersion {
    V1,
}
17
18impl fmt::Display for APIVersion {
19 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
20 match self {
21 APIVersion::V1 => write!(f, "v1"),
22 }
23 }
24}
25
/// HTTP client for a KoboldAI server.
pub struct KoboldClient {
    /// Base URL including the API prefix, e.g. "http://host:5000/api/v1"
    /// (assembled in `new`).
    api_url: String,
    /// Reusable reqwest client; cheap to clone, holds the connection pool.
    client: reqwest::Client,
}
30
31impl KoboldClient {
32 pub fn new(api_url: &str, api_version: APIVersion) -> Self {
33 let api_url = format!("{api_url}/api/{api_version}");
34 let client = reqwest::Client::new();
35
36 KoboldClient { api_url, client }
37 }
38
39 fn get_result_string(value: serde_json::Value) -> Result<String> {
40 let string = value
41 .get("result")
42 .expect("Property 'result' not found.")
43 .as_str()
44 .expect("Could not convert 'result' into a string.");
45
46 Ok(String::from(string))
47 }
48
49 pub async fn get_version(&self) -> Result<String> {
50 let response = self
51 .client
52 .get(format!("{}/info/version", self.api_url))
53 .send()
54 .await?;
55
56 let version = Self::get_result_string(response.json().await?)?;
57
58 Ok(version)
59 }
60
61 pub async fn get_model(&self) -> Result<Option<Model>> {
62 let response = self
63 .client
64 .get(format!("{}/model", self.api_url))
65 .send()
66 .await?;
67
68 let model_string = Self::get_result_string(response.json().await?)?;
69
70 if model_string.eq("ReadOnly") {
71 Ok(None)
72 } else {
73 Ok(Some(Model::from(model_string)))
74 }
75 }
76
77 pub async fn load_model(&self, model: Model, gpu_layers: Vec<i32>) -> Result<()> {
78 let gpu_layers_breakmodel: String = gpu_layers
79 .iter()
80 .map(|i| i.to_string())
81 .collect::<Vec<String>>()
82 .join(",");
83
84 let response = self
85 .client
86 .put(format!("{}/model", self.api_url))
87 .json(&json!({ "model": model, "gpu_layers": gpu_layers_breakmodel }))
88 .send()
89 .await?;
90
91 if response.status().is_success() {
92 Ok(())
93 } else {
94 bail!(response.status())
95 }
96 }
97
98 pub async fn generate(
99 &self,
100 prompt: &str,
101 settings: GenerationSettings,
102 ) -> Result<Vec<String>> {
103 let settings_value = inject_prompt(prompt, settings)?;
104
105 let response = self
106 .client
107 .post(format!("{}/generate", self.api_url))
108 .json(&settings_value)
109 .send()
110 .await?;
111
112 if response.status().is_success() {
113 let generation_response = response.json::<GenerationOkResult>().await?;
114
115 let generations = generation_response
116 .results
117 .into_iter()
118 .map(|generation| generation.text)
119 .collect();
120
121 Ok(generations)
122 } else {
123 println!("HTTP {}", response.status());
124
125 let error_response = &response.json::<GenerationErrorResult>().await?;
126
127 bail!(format!(
128 "[{:?}] {}",
129 &error_response.detail.error_type, &error_response.detail.message
130 ));
131 }
132 }
133}
134
135fn inject_prompt(prompt: &str, settings: GenerationSettings) -> Result<Value> {
136 let mut settings_value = serde_json::to_value(&settings)?;
137
138 settings_value = match settings_value {
140 Value::Object(m) => {
141 let mut m = m.clone();
142 m.insert(
143 "prompt".into(),
144 serde_json::Value::String(prompt.to_string()),
145 );
146
147 Value::Object(m)
148 }
149 _ => bail!("GenerationSettings object was not an object!"),
150 };
151
152 Ok(settings_value)
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: the `debug_*` tests hit a live KoboldAI server on
    // localhost:5000 and are intended for manual runs, not CI.

    #[tokio::test]
    async fn debug_get_version() {
        let kai = KoboldClient::new("http://localhost:5000", APIVersion::V1);
        kai.get_version().await.unwrap();
    }

    #[tokio::test]
    async fn debug_get_model() {
        let kai = KoboldClient::new("http://localhost:5000", APIVersion::V1);
        let model = kai.get_model().await.unwrap();
        dbg!(model);
    }

    #[tokio::test]
    async fn debug_load_model() {
        let kai = KoboldClient::new("http://localhost:5000", APIVersion::V1);
        let gpu_layers = vec![28];
        kai.load_model(Model::from("./pygmalion-6b_dev"), gpu_layers)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn debug_generate() {
        let kai = KoboldClient::new("http://localhost:5000", APIVersion::V1);
        let settings = GenerationSettings::default();

        let prompt = "You: Hi. How are you?";
        let response = kai.generate(prompt, settings).await;

        // `panic!` with the error replaces the old `dbg! + assert!(false)`
        // (and the redundant `assert!(true)` on the success path).
        match response {
            Ok(generations) => println!("{}{}", prompt, generations[0]),
            Err(err) => panic!("generation failed: {err:?}"),
        };
    }

    #[test]
    fn inject_prompt_into_generation_settings_object() {
        let settings = GenerationSettings::default();
        // `"test"` is already a `&str`; the old `.into()` was a no-op.
        let result = inject_prompt("test", settings).unwrap();

        assert_eq!(
            result
                .as_object()
                .unwrap()
                .get("prompt")
                .unwrap()
                .as_str()
                .unwrap(),
            "test"
        )
    }

    #[test]
    fn decode_successful_generation() {
        let value = json!({
            "results": [{
                "text": "testing"
            }]
        });

        let ok_result = serde_json::from_value::<GenerationOkResult>(value).unwrap();

        assert_eq!(ok_result.results[0].text, "testing")
    }

    #[test]
    fn decode_failed_generation() {
        let value = json!({
            "detail": {
                "msg": "test",
                "type": "not_implemented"
            }
        });

        let err_result = serde_json::from_value::<GenerationErrorResult>(value).unwrap();

        assert_eq!(err_result.detail.message, "test")
    }
}