use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use plainllm::client;
mod utils;
// Continent a country belongs to. The derived JSON schema restricts the model's
// answer to exactly these variants, so every inhabited continent must be listed
// (North and South America are grouped as `Americas`) or capitals elsewhere could
// not be represented. New variants are appended to keep the existing order stable.
// Plain `//` comments are used deliberately: `///` doc comments would be turned into
// schema descriptions by `JsonSchema`.
#[derive(Serialize, Deserialize, JsonSchema, Debug)]
enum Continent {
    Europe,
    Americas,
    Asia,
    Africa,
    Oceania,
}
// Structured reply requested for the "capital of ..." question in `main`.
// Comments here are plain `//` on purpose: `///` doc comments would be picked up by
// `JsonSchema` as descriptions and alter the generated schema.
#[derive(Debug)]
#[derive(Serialize, Deserialize, JsonSchema)]
struct Capital {
    // Name of the capital city.
    city_name: String,
    // Country whose capital this is.
    country: String,
    // Continent of that country, limited to the `Continent` variants.
    continent: Continent,
}
// Verdict returned when the model is asked to check the previous `Capital` answer.
// The prompt in `main` asks for `mistake` and `correction` to be null when the
// answer is correct. Plain `//` comments keep the generated schema unchanged.
#[derive(Debug)]
#[derive(Serialize, Deserialize, JsonSchema)]
struct ValidationReply {
    // Whether the previous answer was judged correct.
    is_correct: bool,
    // Explanation of what was wrong, if anything.
    mistake: Option<String>,
    // Corrected answer, if one was needed.
    correction: Option<String>,
}
/// Two-turn structured-output demo.
///
/// 1. Asks the model for the capital of Italy and deserializes the reply into
///    [`Capital`].
/// 2. Appends that reply to the conversation and asks the model to check it,
///    deserializing the verdict into [`ValidationReply`].
///
/// Both assistant turns are recorded in the transcript as JSON, matching the
/// system prompt's instruction that the assistant replies in JSON.
///
/// # Panics
///
/// Panics if either `call_llm_structured` call returns an error.
#[tokio::main]
async fn main() {
    utils::init_logger();
    let args = utils::get_client_args();
    // Same low sampling temperature for both turns.
    let opts = client::LLMOptions::new().temperature(0.20);

    let message = "What is the capital of Italy? Reply with city name, country and continent";
    let mut messages = vec![
        client::Message::new(
            "system",
            r#"You are a helpful assistant who replies in JSON. Reply with structured output according to the schema."#,
        ),
        client::Message::new("user", message),
    ];
    println!("Question:\n{}\n", message);

    // Turn 1: structured answer. A copy of the conversation is passed because
    // `messages` keeps growing after this call.
    let response: Capital = args
        .client
        .call_llm_structured(&args.model, messages.clone(), &opts)
        .await
        .expect("Couldn't get response from llm");
    // A derived Serialize impl over Strings and a unit enum cannot fail.
    let out = serde_json::to_string(&response).expect("Capital always serializes to JSON");
    println!("Structured Output:\n{}\n", out);
    // `out` is already a String; pass it directly instead of re-formatting it.
    messages.push(client::Message::new("assistant", &out));

    // Turn 2: ask the model to validate its own answer.
    messages.push(client::Message::new(
        "user",
        "Is the reply from the assistant above correct or not? If not correct, describe the exact reason why it isn't correct in the `mistake` attribute. Provide the correct answer in the `correction` attribute. Make `correction` and `mistake` be null if answer is already correct. Only try to correct the facts/knowledge/statements, not the syntax or if it's JSON or not. Don't correct syntax, just ensure the facts are correct or not. Enums are valid replacements instead of Strings, don't complain about that.",
    ));
    let response: ValidationReply = args
        .client
        .call_llm_structured(&args.model, messages.clone(), &opts)
        .await
        .expect("Couldn't get response from llm");
    // Record the verdict as JSON, like the first assistant turn, rather than as
    // Rust `Debug` output, so the transcript stays a consistent JSON conversation.
    let verdict =
        serde_json::to_string(&response).expect("ValidationReply always serializes to JSON");
    messages.push(client::Message::new("assistant", &verdict));
    println!("Validation:\n{:#?}\n", response);
    println!("All Messages:\n{:#?}\n", messages);
}