use serde::{Deserialize, Serialize};
/// A single chat message sent in the `messages` array of an Ollama `/api/chat` request.
#[derive(Serialize, Deserialize, Debug)]
struct Message {
    /// Author of the message; `chat_completion` passes its `role` argument through here
    /// (e.g. `"user"`).
    role: String,
    /// Text of the message.
    content: String,
}
/// JSON body posted to `http://localhost:11434/api/chat` by `chat_completion`.
#[derive(Serialize, Deserialize, Debug)]
struct ChatRequest {
    /// Name of the model that should answer.
    model: String,
    /// Conversation history sent with the request.
    messages: Vec<Message>,
    /// Whether the server should stream its reply; `chat_completion` always sends `false`.
    stream: bool,
}
/// JSON body posted to `http://localhost:11434/api/pull` by `pull_model`.
// NOTE(review): newer Ollama API docs call the model field `model`; confirm the targeted
// server version still accepts `name`.
#[derive(Serialize, Deserialize, Debug)]
struct PullRequest {
    /// Name of the model to download.
    name: String,
    /// Whether the server should stream progress records instead of one final object.
    stream: bool,
}
/// JSON body posted to `http://localhost:11434/api/embeddings` by `gen_embeddings`.
#[derive(Serialize, Deserialize, Debug)]
struct EmbRequest {
    /// Name of the model used to compute the embedding.
    model: String,
    /// Text to embed.
    prompt: String,
}
/// JSON body posted to `http://localhost:11434/api/push` by `push_models`.
// NOTE(review): newer Ollama API docs call the model field `model`; confirm the targeted
// server version still accepts `name`.
#[derive(Serialize, Deserialize, Debug)]
struct PushRequest {
    /// Name of the model to upload.
    name: String,
    /// Whether the server should stream progress records instead of one final object.
    stream: bool,
}
/// Sends a single-message, non-streaming chat request to the local Ollama server
/// (`POST http://localhost:11434/api/chat`) and pretty-prints the JSON reply to stdout.
///
/// `#[tokio::main]` turns this into a blocking function: every call builds its own Tokio
/// runtime and blocks on the request.
///
/// # Arguments
/// * `model` - name of the model to chat with.
/// * `content` - text of the single message that is sent.
/// * `role` - role attached to that message (e.g. `"user"`).
///
/// # Errors
/// Returns the `reqwest::Error` when the request cannot be sent, when the server answers
/// with a non-success HTTP status (the error body is printed to stderr first), or when a
/// successful body is not valid JSON.
///
/// # Panics
/// Panics if called from inside an already running Tokio runtime, because the runtime
/// created by `#[tokio::main]` cannot block within another one.
#[tokio::main]
pub async fn chat_completion(model: &str, content: &str, role: &str) -> Result<(), reqwest::Error> {
    let req = ChatRequest {
        model: model.to_string(),
        messages: vec![Message {
            role: role.to_string(),
            content: content.to_string(),
        }],
        // The reply is decoded below as one JSON object, so streaming stays off.
        stream: false,
    };
    let response = reqwest::Client::new()
        .post("http://localhost:11434/api/chat")
        .json(&req)
        .send()
        .await?;
    // Without this check an HTTP failure (e.g. unknown model) was reported as Ok(()).
    // Show the server's explanation, then hand the status error back to the caller.
    let status = response.error_for_status_ref().map(|_| ());
    if let Err(status_err) = status {
        eprintln!("{}", response.text().await?);
        return Err(status_err);
    }
    let response_json: serde_json::Value = response.json().await?;
    println!("{:#?}", response_json);
    Ok(())
}
#[tokio::main]
pub async fn pull_model(name: &str, stream_mode: bool) -> Result<(), reqwest::Error> {
let req = PullRequest {
name: String::from(name),
stream: stream_mode,
};
let response = reqwest::Client::new()
.post("http://localhost:11434/api/pull")
.json(&req)
.send()
.await?;
let response_json: serde_json::Value = response.json().await?;
println!("{:#?}", response_json);
Ok(())
}
/// Requests an embedding of `prompt` from `model` on the local Ollama server
/// (`POST http://localhost:11434/api/embeddings`) and pretty-prints the JSON reply to stdout.
///
/// `#[tokio::main]` turns this into a blocking function: every call builds its own Tokio
/// runtime and blocks on the request.
///
/// # Errors
/// Returns the `reqwest::Error` when the request cannot be sent, when the server answers
/// with a non-success HTTP status (the error body is printed to stderr first), or when a
/// successful body is not valid JSON.
///
/// # Panics
/// Panics if called from inside an already running Tokio runtime.
#[tokio::main]
pub async fn gen_embeddings(model: &str, prompt: &str) -> Result<(), reqwest::Error> {
    let req = EmbRequest {
        model: model.to_string(),
        prompt: prompt.to_string(),
    };
    let response = reqwest::Client::new()
        .post("http://localhost:11434/api/embeddings")
        .json(&req)
        .send()
        .await?;
    // Report HTTP failures as errors instead of Ok(()), after showing the server's reason.
    let status = response.error_for_status_ref().map(|_| ());
    if let Err(status_err) = status {
        eprintln!("{}", response.text().await?);
        return Err(status_err);
    }
    let response_json: serde_json::Value = response.json().await?;
    println!("{:#?}", response_json);
    Ok(())
}
#[tokio::main]
pub async fn list_models() -> Result<(), reqwest::Error> {
let response = reqwest::get("http://localhost:11434/api/ps").await?;
println!("{:#?}", response);
Ok(())
}
#[tokio::main]
pub async fn push_models(name: &str, stream_mode: bool) -> Result<(), reqwest::Error> {
let req = PushRequest {
name: String::from(name),
stream: stream_mode,
};
let response = reqwest::Client::new()
.post("http://localhost:11434/api/push")
.json(&req)
.send()
.await?;
let response_json: serde_json::Value = response.json().await?;
println!("{:#?}", response_json);
Ok(())
}
// Smoke tests against a live Ollama server at localhost:11434.
//
// They are #[ignore]d because they need that server, and `pushing` even tries to upload a
// model to a registry, so a plain `cargo test` must not depend on them. Run them explicitly
// with `cargo test -- --ignored`. "model_name" is a placeholder, so the calls that use it
// only check that nothing panics and discard their Result on purpose.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore = "requires a running Ollama server on localhost:11434"]
    fn chat_test() {
        let _ = chat_completion("model_name", "Hello!", "user");
    }

    #[test]
    #[ignore = "requires a running Ollama server on localhost:11434"]
    fn pull_test() {
        let _ = pull_model("model_name", false);
    }

    #[test]
    #[ignore = "requires a running Ollama server on localhost:11434"]
    fn gen_embed_test() {
        let _ = gen_embeddings("model_name", "Generate embeddings from this prompt");
    }

    #[test]
    #[ignore = "requires a running Ollama server on localhost:11434"]
    fn listing() {
        // No placeholder model involved, so a reachable server must answer successfully.
        list_models().expect("listing models from the local Ollama server failed");
    }

    #[test]
    #[ignore = "requires a running Ollama server on localhost:11434 and pushes to a registry"]
    fn pushing() {
        let _ = push_models("model_name", true);
    }
}