use std::collections::BTreeMap;
use http::Method;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::generated::endpoints;
use super::{
CompletionsResource, DeleteResponse, EmbeddingResponse, EmbeddingsResource, JsonRequestBuilder,
ListRequestBuilder, Model, ModelsResource, ModerationsResource, encode_path_segment,
};
/// Completion object, (de)serialized with serde.
///
/// This is the type `CompletionsResource::create` parameterizes its
/// `JsonRequestBuilder` with (presumably the decoded response body — confirm
/// against `JsonRequestBuilder`).
///
/// Deserialization requires the `id` key. Fields marked `#[serde(default)]`
/// fall back to their type's default when the key is absent, `Option` fields
/// become `None`, and keys matching no named field are collected into `extra`.
/// On serialization, `None` fields are emitted as `null` (nothing is skipped)
/// and the entries of `extra` are written inline next to the named fields.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Completion {
/// Value of the `id` key. Required: deserialization fails when it is missing.
pub id: String,
/// Value of the `object` key; empty string when the key is absent.
#[serde(default)]
pub object: String,
/// Value of the `created` key, if present (presumably a Unix timestamp — TODO confirm).
pub created: Option<i64>,
/// Value of the `model` key; empty string when the key is absent.
#[serde(default)]
pub model: String,
/// Entries of the `choices` array; empty when the key is absent.
#[serde(default)]
pub choices: Vec<CompletionChoice>,
/// Contents of the `usage` key, if present.
pub usage: Option<CompletionUsage>,
/// Value of the `system_fingerprint` key, if present.
pub system_fingerprint: Option<String>,
/// Catch-all for every key not matched by a named field above.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
/// One element of `Completion::choices`.
///
/// Deserialization requires the `index` key; `text` defaults to an empty
/// string, `Option` fields default to `None`, and unrecognized keys are kept
/// in `extra`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CompletionChoice {
/// Value of the `finish_reason` key, if present. Kept as a free-form string
/// rather than an enum, so any value is accepted.
pub finish_reason: Option<String>,
/// Value of the `index` key. Required: deserialization fails when it is missing.
pub index: u32,
/// Contents of the `logprobs` key, if present.
pub logprobs: Option<CompletionLogProbs>,
/// Value of the `text` key; empty string when the key is absent.
#[serde(default)]
pub text: String,
/// Catch-all for every key not matched by a named field above.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
/// Log-probability data attached to a `CompletionChoice`.
///
/// Every named field is a list that defaults to empty when its key is absent;
/// unrecognized keys are kept in `extra`.
///
/// NOTE(review): `#[serde(default)]` only covers a *missing* key. An explicit
/// `null` for one of these lists, or a `null` element inside one, would fail
/// to deserialize into these types — confirm the API never sends either.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CompletionLogProbs {
/// Entries of the `text_offset` array; empty when the key is absent.
#[serde(default)]
pub text_offset: Vec<i64>,
/// Entries of the `token_logprobs` array; empty when the key is absent.
#[serde(default)]
pub token_logprobs: Vec<f64>,
/// Entries of the `tokens` array; empty when the key is absent.
#[serde(default)]
pub tokens: Vec<String>,
/// Entries of the `top_logprobs` array, each a string-to-float map
/// (presumably token -> log probability — TODO confirm); empty when the key is absent.
#[serde(default)]
pub top_logprobs: Vec<BTreeMap<String, f64>>,
/// Catch-all for every key not matched by a named field above.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
/// Token accounting carried in `Completion::usage`.
///
/// The three counters have no serde default, so all of them must be present
/// for deserialization to succeed. The two detail objects are optional, and
/// unrecognized keys are kept in `extra`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CompletionUsage {
/// Value of the `completion_tokens` key. Required.
pub completion_tokens: u64,
/// Value of the `prompt_tokens` key. Required.
pub prompt_tokens: u64,
/// Value of the `total_tokens` key. Required. Stored as received; this type
/// does not check it against the other two counters.
pub total_tokens: u64,
/// Contents of the `prompt_tokens_details` key, if present.
pub prompt_tokens_details: Option<CompletionUsagePromptTokensDetails>,
/// Contents of the `completion_tokens_details` key, if present.
pub completion_tokens_details: Option<CompletionUsageCompletionTokensDetails>,
/// Catch-all for every key not matched by a named field above.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
/// Optional breakdown stored in `CompletionUsage::prompt_tokens_details`.
///
/// Both counters are optional and become `None` when absent; unrecognized
/// keys are kept in `extra`. Unlike the other usage types in this file, this
/// one also derives `PartialEq`.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct CompletionUsagePromptTokensDetails {
/// Value of the `audio_tokens` key, if present.
pub audio_tokens: Option<u64>,
/// Value of the `cached_tokens` key, if present.
pub cached_tokens: Option<u64>,
/// Catch-all for every key not matched by a named field above.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
/// Optional breakdown stored in `CompletionUsage::completion_tokens_details`.
///
/// Every counter is optional and becomes `None` when absent; unrecognized
/// keys are kept in `extra`. Derives `PartialEq` so values can be compared.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct CompletionUsageCompletionTokensDetails {
/// Value of the `accepted_prediction_tokens` key, if present.
pub accepted_prediction_tokens: Option<u64>,
/// Value of the `audio_tokens` key, if present.
pub audio_tokens: Option<u64>,
/// Value of the `reasoning_tokens` key, if present.
pub reasoning_tokens: Option<u64>,
/// Value of the `rejected_prediction_tokens` key, if present.
pub rejected_prediction_tokens: Option<u64>,
/// Catch-all for every key not matched by a named field above.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
/// Moderation result envelope, (de)serialized with serde.
///
/// This is the type `ModerationsResource::create` parameterizes its
/// `JsonRequestBuilder` with (presumably the decoded response body — confirm
/// against `JsonRequestBuilder`).
///
/// Deserialization requires the `id` key; `model` and `results` fall back to
/// their empty defaults when absent, and unrecognized keys are kept in `extra`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ModerationCreateResponse {
/// Value of the `id` key. Required: deserialization fails when it is missing.
pub id: String,
/// Value of the `model` key; empty string when the key is absent.
#[serde(default)]
pub model: String,
/// Entries of the `results` array; empty when the key is absent.
#[serde(default)]
pub results: Vec<ModerationResult>,
/// Catch-all for every key not matched by a named field above.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
/// One element of `ModerationCreateResponse::results`.
///
/// Every named field carries `#[serde(default)]`, so an empty JSON object
/// deserializes successfully into empty maps and `flagged == false`.
/// Categories are keyed by string rather than modelled as fixed fields, so
/// any category name is accepted. Unrecognized keys are kept in `extra`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ModerationResult {
/// Contents of the `categories` object as a name -> bool map; empty when absent.
#[serde(default)]
pub categories: BTreeMap<String, bool>,
/// Contents of the `category_applied_input_types` object as a
/// name -> list-of-strings map; empty when absent.
#[serde(default)]
pub category_applied_input_types: BTreeMap<String, Vec<String>>,
/// Contents of the `category_scores` object as a name -> float map; empty when absent.
#[serde(default)]
pub category_scores: BTreeMap<String, f64>,
/// Value of the `flagged` key; `false` when the key is absent.
#[serde(default)]
pub flagged: bool,
/// Catch-all for every key not matched by a named field above.
#[serde(flatten)]
pub extra: BTreeMap<String, Value>,
}
impl CompletionsResource {
    /// Returns a `POST` request builder typed to [`Completion`].
    ///
    /// The endpoint id and path template are read from the generated
    /// `endpoints::core::COMPLETIONS_CREATE` descriptor rather than being
    /// spelled out here. The builder receives its own clone of this
    /// resource's client.
    pub fn create(&self) -> JsonRequestBuilder<Completion> {
        let spec = endpoints::core::COMPLETIONS_CREATE;
        let client = self.client.clone();
        JsonRequestBuilder::new(client, spec.id, Method::POST, spec.template)
    }
}
impl EmbeddingsResource {
    /// Returns a `POST` request builder for `/embeddings`, typed to
    /// [`EmbeddingResponse`] and registered under the id `embeddings.create`.
    ///
    /// The builder receives its own clone of this resource's client.
    ///
    /// NOTE(review): the id and path are literals here, whereas the
    /// completions and moderations builders in this file read theirs from
    /// `endpoints::core` — confirm whether a generated constant exists for
    /// this endpoint.
    pub fn create(&self) -> JsonRequestBuilder<EmbeddingResponse> {
        let client = self.client.clone();
        JsonRequestBuilder::new(client, "embeddings.create", Method::POST, "/embeddings")
    }
}
impl ModerationsResource {
    /// Returns a `POST` request builder typed to [`ModerationCreateResponse`].
    ///
    /// The endpoint id and path template are read from the generated
    /// `endpoints::core::MODERATIONS_CREATE` descriptor. The builder receives
    /// its own clone of this resource's client.
    pub fn create(&self) -> JsonRequestBuilder<ModerationCreateResponse> {
        let spec = endpoints::core::MODERATIONS_CREATE;
        let client = self.client.clone();
        JsonRequestBuilder::new(client, spec.id, Method::POST, spec.template)
    }
}
impl ModelsResource {
    /// Returns a list request builder for `/models`, typed to [`Model`] and
    /// registered under the id `models.list`.
    pub fn list(&self) -> ListRequestBuilder<Model> {
        let client = self.client.clone();
        ListRequestBuilder::new(client, "models.list", "/models")
    }

    /// Returns a `GET` request builder for a single model, registered under
    /// the id `models.retrieve`.
    ///
    /// `model_id` is passed through `encode_path_segment` before being
    /// appended to `/models/`.
    pub fn retrieve(&self, model_id: impl Into<String>) -> JsonRequestBuilder<Model> {
        let client = self.client.clone();
        let segment = encode_path_segment(model_id.into());
        let path = format!("/models/{}", segment);
        JsonRequestBuilder::new(client, "models.retrieve", Method::GET, path)
    }

    /// Returns a `DELETE` request builder for a single model, typed to
    /// [`DeleteResponse`] and registered under the id `models.delete`.
    ///
    /// `model_id` is passed through `encode_path_segment` before being
    /// appended to `/models/`.
    pub fn delete(&self, model_id: impl Into<String>) -> JsonRequestBuilder<DeleteResponse> {
        let client = self.client.clone();
        let segment = encode_path_segment(model_id.into());
        let path = format!("/models/{}", segment);
        JsonRequestBuilder::new(client, "models.delete", Method::DELETE, path)
    }
}