use serde::{Deserialize, Serialize};
use serde_json::Value;
use super::common::{GenerationRequest, UsageInfo};
/// Request body for a classification call.
///
/// Fields that are `None` are omitted when the request is serialized
/// (via `serde_with::skip_serializing_none`).
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ClassifyRequest {
    /// Name of the model to classify with.
    pub model: String,
    /// Raw input payload, kept as untyped JSON.
    // Only a JSON string, or the string elements of a JSON array, contribute to
    // routing text in `extract_text_for_routing`; other shapes yield no text there.
    pub input: Value,
    /// Optional caller-supplied user identifier.
    // NOTE(review): not interpreted in this file — confirm how downstream uses it.
    pub user: Option<String>,
    /// Optional request identifier.
    // NOTE(review): presumably a caller-chosen request id; not interpreted here — confirm.
    pub rid: Option<String>,
    /// Optional request priority.
    // NOTE(review): ordering (higher vs. lower wins) and valid range are not visible here — confirm.
    pub priority: Option<i32>,
}
impl GenerationRequest for ClassifyRequest {
    /// Classification requests never stream; this always returns `false`.
    fn is_stream(&self) -> bool {
        false
    }

    /// Returns the requested model name; always `Some`.
    fn get_model(&self) -> Option<&str> {
        Some(self.model.as_str())
    }

    /// Derives the text used for routing from `input`.
    ///
    /// - A JSON string is returned unchanged.
    /// - A JSON array contributes only its string elements, joined by single spaces;
    ///   non-string elements are skipped.
    /// - Any other JSON value produces an empty string.
    fn extract_text_for_routing(&self) -> String {
        if let Some(text) = self.input.as_str() {
            return text.to_owned();
        }
        // Non-array, non-string inputs fall through to `String::default()` (empty).
        self.input
            .as_array()
            .map(|items| {
                items
                    .iter()
                    .filter_map(Value::as_str)
                    .collect::<Vec<&str>>()
                    .join(" ")
            })
            .unwrap_or_default()
    }
}
/// A single classification result entry carried in a `ClassifyResponse`.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct ClassifyData {
    /// Index of this result entry.
    // NOTE(review): presumably the position of the matching input item — confirm with the producer.
    pub index: u32,
    /// Label assigned to this entry.
    pub label: String,
    /// Score values for this entry.
    // NOTE(review): presumably one probability per class (len == num_classes) — not enforced here.
    pub probs: Vec<f32>,
    /// Number of classes.
    pub num_classes: u32,
}
/// Response body for a classification call.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct ClassifyResponse {
    /// Response identifier.
    pub id: String,
    /// Object type tag; `ClassifyResponse::new` sets this to `"list"`.
    pub object: String,
    /// Creation timestamp.
    // NOTE(review): unit/epoch not visible here (likely Unix seconds) — confirm.
    pub created: u64,
    /// Model name reported for this response.
    pub model: String,
    /// Classification result entries.
    pub data: Vec<ClassifyData>,
    /// Usage accounting for the request.
    pub usage: UsageInfo,
}
impl ClassifyResponse {
    /// Assembles a [`ClassifyResponse`] from its parts.
    ///
    /// The `object` tag is not a parameter: it is always set to `"list"`.
    pub fn new(
        id: String,
        model: String,
        created: u64,
        data: Vec<ClassifyData>,
        usage: UsageInfo,
    ) -> Self {
        let object = String::from("list");
        Self {
            id,
            object,
            created,
            model,
            data,
            usage,
        }
    }
}