use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::{EmbedModel, Truncate};
/// Request payload for a classification call.
///
/// Borrows its inputs and examples for `'input`, so a request can be built without copying
/// the caller's data. Every `Option` field is left out of the serialized output when it is
/// `None`, rather than being sent as an explicit `null`.
#[derive(Serialize, Default, Debug)]
pub struct ClassifyRequest<'input> {
    /// Model to classify with; omitted from the payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<EmbedModel>,
    /// Name of a preset to apply; omitted from the payload when `None`.
    /// NOTE(review): presumably a server-side preset identifier — confirm against the API docs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preset: Option<String>,
    /// Texts to classify.
    pub inputs: &'input [String],
    /// Labelled examples (text/label pairs) to classify against.
    pub examples: &'input [ClassifyExample<'input>],
    /// How over-long inputs should be truncated; omitted from the payload when `None`.
    /// Skipped like the other optional fields so a default request does not send
    /// `"truncate": null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub truncate: Option<Truncate>,
}
/// One labelled example: a piece of text together with the label it belongs to.
///
/// The example holds only borrowed string slices, so it is `Copy`. Callers can reuse or
/// duplicate examples freely when building the `examples` slice of a request.
#[derive(Serialize, Debug, Clone, Copy, PartialEq, Eq)]
pub struct ClassifyExample<'input> {
    /// Example text.
    pub text: &'input str,
    /// Label associated with `text`.
    pub label: &'input str,
}
/// A label paired with a confidence score.
///
/// Derives `PartialEq` (and `Clone`) like the other deserialized response types in this
/// module, so values can be compared in assertions and copied out of responses.
#[derive(Deserialize, Debug, Clone, PartialEq)]
pub struct Confidence {
    /// Label name.
    pub label: String,
    /// Confidence score for `label`.
    /// NOTE(review): presumably in the range [0, 1] — confirm against the API docs.
    pub confidence: f64,
}
/// Per-label data carried in the `labels` map of a classification result.
#[derive(Debug, PartialEq, Deserialize)]
pub struct LabelProperties {
    /// Confidence score reported for this label.
    pub confidence: f64,
}
/// Classification result for a single input text.
///
/// NOTE(review): `confidence` here is `f32`, whereas `LabelProperties::confidence` and
/// `Confidence::confidence` are `f64`. Widening it would change a public field type (a
/// breaking change), so it is intentionally left as-is.
#[derive(Debug, PartialEq, Deserialize)]
pub struct Classification {
    /// Identifier of this result.
    pub id: String,
    /// Predicted label.
    pub prediction: String,
    /// Confidence score of `prediction`.
    pub confidence: f32,
    /// Per-label data; presumably keyed by label name — confirm against the API docs.
    pub labels: HashMap<String, LabelProperties>,
    /// The input text this result refers to.
    pub input: String,
}
/// Crate-internal wrapper for the deserialized body of a classification response.
#[derive(Debug, Deserialize)]
pub(crate) struct ClassifyResponse {
    /// Classification results.
    /// NOTE(review): presumably one entry per request input, in request order — confirm.
    pub classifications: Vec<Classification>,
}