use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::fmt::Display;
use thiserror::Error;
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Starts a mock server that answers every `POST /chat/completions` with `template`.
    ///
    /// The returned server stops when dropped, so keep it alive while requests are made.
    async fn mock_completions(template: ResponseTemplate) -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(template)
            .mount(&server)
            .await;
        server
    }

    /// Builds a `200 OK` chat-completion body whose single choice carries `content`.
    fn completion_reply(content: &str) -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "choices": [{
                "message": {
                    "content": content
                }
            }]
        }))
    }

    #[test]
    fn test_vibesort_config() {
        let sorter = Vibesort::new("key", "model", "url");
        assert_eq!(sorter.api_key, "key");
        assert_eq!(sorter.model, "model");
        assert_eq!(sorter.base_url, "url");
    }

    #[tokio::test]
    async fn test_vibesort_with_mock() {
        let server = mock_completions(completion_reply("[1,1,2,3,4,5,6,9]")).await;
        let base_url = server.uri();
        let sorter = Vibesort::new("test-api-key", "test-model", base_url.as_str());
        let numbers = vec![3, 1, 4, 1, 5, 9, 2, 6];
        let sorted = sorter.sort(&numbers).await.expect("sort should succeed");
        assert_eq!(sorted, vec![1, 1, 2, 3, 4, 5, 6, 9]);
    }

    #[tokio::test]
    async fn test_vibesort_strips_code_fence() {
        // Chat models frequently wrap JSON in a Markdown fence despite instructions.
        let server = mock_completions(completion_reply("```json\n[1,2,3]\n```")).await;
        let base_url = server.uri();
        let sorter = Vibesort::new("test-api-key", "test-model", base_url.as_str());
        let numbers = vec![3, 1, 2];
        let sorted = sorter.sort(&numbers).await.expect("fenced reply should parse");
        assert_eq!(sorted, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_vibesort_api_error() {
        let server = mock_completions(
            ResponseTemplate::new(500).set_body_string("Internal Server Error"),
        )
        .await;
        let base_url = server.uri();
        let sorter = Vibesort::new("test-api-key", "test-model", base_url.as_str());
        let numbers = vec![3, 1, 4];
        let result = sorter.sort(&numbers).await;
        assert!(
            matches!(result, Err(VibesortError::ApiError(_))),
            "Expected ApiError, got {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_vibesort_parse_error() {
        let server =
            mock_completions(completion_reply("Here is the sorted array: 1, 2, 3")).await;
        let base_url = server.uri();
        let sorter = Vibesort::new("test-api-key", "test-model", base_url.as_str());
        let numbers = vec![3, 1, 2];
        match sorter.sort(&numbers).await {
            Err(VibesortError::ParseError(msg)) => {
                assert!(msg.contains("Here is the sorted array: 1, 2, 3"));
            }
            other => panic!("Expected ParseError, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_vibesort_str_with_mock() {
        let server =
            mock_completions(completion_reply("[\"apple\",\"banana\",\"cherry\"]")).await;
        let base_url = server.uri();
        let sorter = Vibesort::new("test-api-key", "test-model", base_url.as_str());
        let words = vec!["banana", "apple", "cherry"];
        let sorted = sorter.sort_str(&words).await.expect("sort_str should succeed");
        assert_eq!(sorted, vec!["apple", "banana", "cherry"]);
    }
}
#[derive(Error, Debug)]
pub enum VibesortError {
#[error("HTTP request failed: {0}")]
HttpError(#[from] reqwest::Error),
#[error("JSON parsing failed: {0}")]
JsonError(#[from] serde_json::Error),
#[error("LLM API error: {0}")]
ApiError(String),
#[error("Invalid response format from LLM")]
InvalidResponse,
#[error("Failed to parse LLM response as sorted array. LLM returned: {0}")]
ParseError(String),
}
/// Request body for the `/chat/completions` endpoint.
///
/// Everything is borrowed from the caller, so building a request allocates only the message list.
#[derive(Serialize, Debug)]
struct ChatRequest<'a> {
    /// Model identifier to run the completion with.
    model: &'a str,
    /// Conversation sent to the model, in order.
    messages: Vec<ChatMessage<'a>>,
    /// Sampling temperature forwarded to the API.
    temperature: f32,
}

/// One entry of [`ChatRequest::messages`].
#[derive(Serialize, Debug)]
struct ChatMessage<'a> {
    /// Speaker of the message (this file sends `"system"` and `"user"`).
    role: &'a str,
    /// Text of the message.
    content: &'a str,
}
/// The subset of the `/chat/completions` response body this module reads.
#[derive(Deserialize, Debug)]
struct ChatResponse {
    /// Candidate completions returned by the API.
    choices: Vec<Choice>,
}

/// A single candidate completion.
#[derive(Deserialize, Debug)]
struct Choice {
    /// The message the model produced for this candidate.
    message: ChatMessageResponse,
}

/// The model-authored message inside a [`Choice`].
#[derive(Deserialize, Debug)]
struct ChatMessageResponse {
    /// Raw text produced by the model.
    content: String,
}
/// Settings for sorting values by asking an OpenAI-compatible chat-completions API.
///
/// Every field is a borrowed `&str`, so the struct is `Copy` and cheap to pass around.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct Vibesort<'a> {
    /// Secret credential used to authenticate requests.
    pub api_key: &'a str,
    /// Model identifier to request completions from.
    pub model: &'a str,
    /// Root URL of the API.
    pub base_url: &'a str,
}

// `Debug` is written by hand so that `{:?}`-formatting or logging a `Vibesort`
// never prints the API key.
impl std::fmt::Debug for Vibesort<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Vibesort")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .finish()
    }
}
impl<'a> Vibesort<'a> {
    /// Creates a sorter for an OpenAI-compatible chat-completions endpoint.
    ///
    /// `base_url` is the API root (for example `https://api.openai.com/v1`); a trailing `/`
    /// is tolerated. `/chat/completions` is appended to it for every request.
    pub fn new(api_key: &'a str, model: &'a str, base_url: &'a str) -> Self {
        Self {
            api_key,
            model,
            base_url,
        }
    }

    /// Asks the configured LLM to sort `items` in ascending order.
    ///
    /// The items are serialized as a JSON array and sent as the user message, together with a
    /// system prompt asking for only the sorted array. The reply is deserialized back into
    /// `Vec<T>`. A surrounding Markdown code fence (```` ``` ```` / ```` ```json ````) is
    /// stripped first. The reply is trusted to be a correctly ordered permutation of the input.
    /// Only the element count is verified, because `T` has no comparison bounds.
    ///
    /// An empty slice returns an empty `Vec` without contacting the API.
    ///
    /// # Errors
    ///
    /// - [`VibesortError::JsonError`] if `items` cannot be serialized.
    /// - [`VibesortError::HttpError`] if the request fails, or the response body cannot be read
    ///   or decoded as a chat-completion object.
    /// - [`VibesortError::ApiError`] if the API answers with a non-success status.
    /// - [`VibesortError::InvalidResponse`] if the response contains no choices.
    /// - [`VibesortError::ParseError`] if the reply is not a JSON array of `T`, or has a
    ///   different number of elements than `items`. The payload starts with the raw reply.
    pub async fn sort<T>(&self, items: &[T]) -> Result<Vec<T>, VibesortError>
    where
        T: Display + Serialize + DeserializeOwned,
    {
        // Nothing to sort: skip the network round-trip (and any chance of an API error).
        if items.is_empty() {
            return Ok(Vec::new());
        }

        let json_array = serde_json::to_string(items)?;
        // Trim a trailing '/' so a base URL like ".../v1/" does not yield "//chat/completions".
        let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
        let system_prompt = "You are a helpful assistant that sorts arrays. Sort the following JSON array with ascending order and return ONLY the sorted JSON array, nothing else.";
        let request = ChatRequest {
            model: self.model,
            messages: vec![
                ChatMessage {
                    role: "system",
                    content: system_prompt,
                },
                ChatMessage {
                    role: "user",
                    content: &json_array,
                },
            ],
            // Lowest temperature: ask for the least random completion.
            temperature: 0.0,
        };

        // A fresh client (and connection pool) is built per call.
        // `.json()` also sets the `Content-Type: application/json` header.
        let response = reqwest::Client::new()
            .post(&url)
            .bearer_auth(self.api_key)
            .json(&request)
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            // Best effort: an unreadable body must not hide the failing status.
            let error_text = response.text().await.unwrap_or_default();
            return Err(VibesortError::ApiError(format!(
                "API returned status {}\nServer response: {}",
                status, error_text
            )));
        }

        let chat_response: ChatResponse = response.json().await?;
        let content = chat_response
            .choices
            .first()
            .ok_or(VibesortError::InvalidResponse)?
            .message
            .content
            .trim();

        // `ParseError` renders as "... LLM returned: {payload}", so every payload starts with
        // the raw reply and appends the reason in parentheses.
        let sorted: Vec<T> =
            serde_json::from_str(Self::strip_code_fence(content)).map_err(|e| {
                VibesortError::ParseError(format!(
                    "{} (not a valid JSON array: {})",
                    content, e
                ))
            })?;

        // Guard against the model silently dropping or inventing elements.
        if sorted.len() != items.len() {
            return Err(VibesortError::ParseError(format!(
                "{} (expected {} elements, got {})",
                content,
                items.len(),
                sorted.len()
            )));
        }

        Ok(sorted)
    }

    /// Convenience wrapper around [`Vibesort::sort`] for string slices.
    ///
    /// The borrowed strings are copied into owned `String`s because the result is deserialized
    /// from the LLM's reply and cannot borrow from the input.
    ///
    /// # Errors
    ///
    /// Same as [`Vibesort::sort`].
    pub async fn sort_str(&self, items: &[&str]) -> Result<Vec<String>, VibesortError> {
        let owned: Vec<String> = items.iter().map(|s| (*s).to_owned()).collect();
        self.sort(&owned).await
    }

    /// Removes a Markdown code fence wrapped around `text`, if there is one.
    ///
    /// The fence may be ```` ``` ```` or carry an info string such as ```` ```json ````.
    /// Text that does not start with a fence is returned unchanged.
    fn strip_code_fence(text: &str) -> &str {
        let rest = match text.strip_prefix("```") {
            Some(rest) => rest,
            None => return text,
        };
        // Drop the rest of the opening fence line (the optional language tag).
        let body = rest.split_once('\n').map_or(rest, |(_, body)| body);
        body.strip_suffix("```").unwrap_or(body).trim()
    }
}