use serde::{Deserialize, Serialize};
use serde_json::Value;
use super::common::{GenerationRequest, UsageInfo};
/// Request body for a classification call.
///
/// Every `Option` field carries `skip_serializing_if = "Option::is_none"`, so
/// unset optional fields are omitted from serialized output entirely.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ClassifyRequest {
/// Name of the model to run; exposed through `GenerationRequest::get_model`.
pub model: String,
/// Raw JSON input. For routing, a string is used as-is and the string elements
/// of an array are joined with spaces; any other shape yields an empty routing text.
pub input: Value,
/// NOTE(review): presumably an end-user identifier — confirm against the API contract.
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
/// NOTE(review): presumably a caller-supplied request id — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub rid: Option<String>,
/// NOTE(review): scheduling priority; whether larger values mean higher priority
/// is not visible here — confirm with the scheduler.
#[serde(skip_serializing_if = "Option::is_none")]
pub priority: Option<i32>,
/// NOTE(review): presumably toggles metrics logging for this request — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub log_metrics: Option<bool>,
}
impl GenerationRequest for ClassifyRequest {
    /// Always `false`: this implementation never reports a streaming request.
    fn is_stream(&self) -> bool {
        false
    }

    /// Returns the requested model name, which is always present.
    fn get_model(&self) -> Option<&str> {
        Some(self.model.as_str())
    }

    /// Builds the text used for routing decisions from `input`.
    ///
    /// - A JSON string is returned unchanged.
    /// - For a JSON array, only its string elements are kept, joined by a single space.
    ///   Non-string elements are skipped.
    /// - Any other JSON value produces an empty string.
    fn extract_text_for_routing(&self) -> String {
        if let Some(text) = self.input.as_str() {
            return text.to_string();
        }
        match self.input.as_array() {
            Some(items) => {
                let pieces: Vec<&str> = items.iter().filter_map(Value::as_str).collect();
                pieces.join(" ")
            }
            None => String::new(),
        }
    }
}
/// One classification result entry inside a [`ClassifyResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassifyData {
/// NOTE(review): presumably the position of the corresponding input item — confirm.
pub index: u32,
/// NOTE(review): presumably the predicted class label — confirm.
pub label: String,
/// NOTE(review): presumably per-class scores/probabilities; whether they are
/// normalized, and whether `probs.len() == num_classes` always holds, is not enforced here.
pub probs: Vec<f32>,
/// Number of classes reported by the model for this entry.
pub num_classes: u32,
}
/// Response body for a classification call.
///
/// Construct it with `ClassifyResponse::new`, which fixes `object` to `"list"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassifyResponse {
/// Response identifier.
pub id: String,
/// Object-type tag; set to `"list"` by `new`.
pub object: String,
/// NOTE(review): presumably a Unix timestamp in seconds — confirm with the caller.
pub created: u64,
/// Model that produced the results.
pub model: String,
/// Classification results, one entry per [`ClassifyData`].
pub data: Vec<ClassifyData>,
/// NOTE(review): usage accounting (e.g. token counts) — see `UsageInfo` for the exact fields.
pub usage: UsageInfo,
}
impl ClassifyResponse {
    /// Creates a response whose `object` tag is always `"list"`.
    ///
    /// Every other field is moved in unchanged from the corresponding argument.
    pub fn new(
        id: String,
        model: String,
        created: u64,
        data: Vec<ClassifyData>,
        usage: UsageInfo,
    ) -> Self {
        // The object tag is constant for this response type.
        let object = String::from("list");
        Self { id, object, created, model, data, usage }
    }
}