use serde::{Serialize, Deserialize};
use std::collections::HashMap;
use serde_json::{json, Value};
/// A generation request: a prompt plus an optional task and optional
/// free-form parameters.
///
/// `Default` yields an empty prompt with no task and no parameters.
// `Default` is derived rather than hand-written: the manual impl only set
// `String::new()` / `None` / `None`, exactly what the derive produces
// (clippy `derivable_impls`).
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct GenerationRequest {
    /// The prompt text to generate from.
    pub prompt: String,
    /// Optional task name; how it is interpreted is up to the consumer
    /// (not visible here).
    pub task: Option<String>,
    /// Optional extra parameters keyed by name with arbitrary JSON values
    /// (e.g. `"max_tokens"`); `None` when no parameters have been set.
    pub params: Option<HashMap<String, serde_json::Value>>,
}
impl GenerationRequest {
    /// Creates a request with the given `prompt` and no task or parameters.
    #[must_use]
    pub fn new(prompt: String) -> Self {
        GenerationRequest {
            prompt,
            ..Default::default()
        }
    }

    /// Starts a builder chain for a request with `prompt`.
    ///
    /// Same as [`GenerationRequest::new`] but accepts anything convertible into
    /// a `String`. The request itself acts as the builder; end the chain with
    /// [`GenerationRequest::build`].
    #[must_use]
    pub fn builder(prompt: impl Into<String>) -> GenerationRequest {
        GenerationRequest::new(prompt.into())
    }

    /// Sets the task, replacing any previously set task.
    // Consumes `self`: ignoring the return value drops the request, hence `must_use`.
    #[must_use]
    pub fn task(mut self, name: impl Into<String>) -> Self {
        self.task = Some(name.into());
        self
    }

    /// Inserts one parameter, creating the parameter map on first use.
    ///
    /// A later call with the same `key` overwrites the earlier value.
    #[must_use]
    pub fn param(mut self, key: impl Into<String>, value: impl Into<Value>) -> Self {
        self.params
            .get_or_insert_with(HashMap::new)
            .insert(key.into(), value.into());
        self
    }

    /// Sets the `"max_tokens"` parameter to `tokens`, stored as a JSON number.
    #[must_use]
    pub fn max_tokens(self, tokens: u32) -> Self {
        self.param("max_tokens", json!(tokens))
    }

    /// Finishes the builder chain.
    ///
    /// This is an identity operation; it exists so call sites can read
    /// `GenerationRequest::builder(..)...build()`.
    #[must_use]
    pub fn build(self) -> Self {
        self
    }
}
/// Request record held by the LLM manager: the fields of a
/// [`GenerationRequest`] plus (presumably) retry bookkeeping.
// `Debug` added so this public type can be logged like its sibling types.
#[derive(Debug, Clone)]
pub struct LlmManagerRequest {
    /// Prompt text, taken from the originating request.
    pub prompt: String,
    /// Optional task name, taken from the originating request.
    pub task: Option<String>,
    /// Optional extra parameters, taken from the originating request.
    pub params: Option<HashMap<String, serde_json::Value>>,
    /// Attempt counter; starts at 0 in `from_generation_request`.
    /// NOTE(review): presumably incremented per dispatch attempt — confirm.
    pub attempts: usize,
    /// Indices of instances recorded as failed for this request; starts empty.
    /// NOTE(review): presumably used to skip those instances on retry — confirm.
    pub failed_instances: Vec<usize>,
}
impl LlmManagerRequest {
pub fn from_generation_request(request: GenerationRequest) -> Self {
Self {
prompt: request.prompt,
task: request.task,
params: request.params,
attempts: 0,
failed_instances: Vec::new(),
}
}
}
/// Response produced for a manager request.
///
/// NOTE(review): `success` and `error` are independent fields, so the type
/// allows combinations such as `success: true` with `error: Some(..)`.
/// Presumably `error` is only set on failure — confirm with the producer.
/// Field declaration order is also the serialized field order.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct LlmManagerResponse {
/// Generated content (presumably empty on failure — confirm).
pub content: String,
/// Success flag as reported by whoever builds this response.
pub success: bool,
/// Optional error description.
pub error: Option<String>,
}