use serde::{Deserialize, Serialize};
/// Author of a chat message.
///
/// Serialized using lowercase variant names (`"system"`, `"user"`, `"assistant"`, `"tool"`).
/// The enum is fieldless, so it is `Copy` and comparable with `==`, which lets callers write
/// `msg.role == Role::Assistant` instead of `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: Role,
#[serde(default)]
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
/// A tool the model may call, as declared in a chat request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool kind, carried under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function backing this tool.
    pub function: FunctionDefinition,
}

/// Name, optional description and optional parameter schema of a callable function.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FunctionDefinition {
    /// Function name.
    pub name: String,
    /// Human-readable description; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Parameter schema as raw JSON; left out of serialized output when `None`.
    /// NOTE(review): presumably a JSON Schema object. It is not validated here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<serde_json::Value>,
}
/// One tool invocation emitted by the assistant.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolCall {
    /// Identifier later echoed back in a tool message's `tool_call_id`.
    pub id: String,
    /// Call kind, carried under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub call_type: String,
    /// Which function to call and with what arguments.
    pub function: FunctionCall,
}

/// Target function and its arguments for a [`ToolCall`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FunctionCall {
    /// Name of the function to invoke.
    pub name: String,
    /// Arguments as a single string.
    /// NOTE(review): presumably JSON-encoded. It is kept opaque here.
    pub arguments: String,
}
/// How the model should pick tools.
///
/// Untagged, so serde tries the variants in declaration order: a bare JSON string becomes
/// [`ToolChoice::Mode`], and an object with `type` and `function` becomes
/// [`ToolChoice::Specific`].
#[derive(Clone, Debug, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    /// A mode given as a plain string.
    Mode(String),
    /// A specific function the model is directed to call.
    Specific {
        #[serde(rename = "type")]
        tool_type: String,
        function: ToolChoiceFunction,
    },
}

/// Function selector used by [`ToolChoice::Specific`].
#[derive(Clone, Debug, Deserialize)]
pub struct ToolChoiceFunction {
    /// Name of the selected function.
    pub name: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionRequest {
#[serde(default)]
pub model: String,
pub messages: Vec<ChatMessage>,
#[serde(default = "default_max_tokens")]
pub max_tokens: usize,
#[serde(default = "default_temperature")]
pub temperature: f32,
#[serde(default = "default_top_p")]
pub top_p: f32,
#[serde(default)]
pub stream: bool,
#[serde(default)]
pub stop: Option<Vec<String>>,
#[serde(default)]
pub frequency_penalty: f32,
#[serde(default)]
pub presence_penalty: f32,
#[serde(default)]
pub tools: Option<Vec<ToolDefinition>>,
#[serde(default)]
pub tool_choice: Option<ToolChoice>,
}
/// Serde default for `max_tokens` when a request leaves it out.
fn default_max_tokens() -> usize {
    256_usize
}

/// Serde default for `temperature` when a request leaves it out.
fn default_temperature() -> f32 {
    0.7_f32
}

/// Serde default for `top_p` when a request leaves it out.
fn default_top_p() -> f32 {
    0.9_f32
}
/// One candidate answer within a [`ChatCompletionResponse`].
#[derive(Clone, Debug, Serialize)]
pub struct ChatCompletionChoice {
    /// Position of this choice in `choices`.
    pub index: usize,
    /// The generated message.
    pub message: ChatMessage,
    /// Why generation stopped.
    pub finish_reason: String,
}

/// Non-streaming chat-completion response body.
#[derive(Clone, Debug, Serialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    /// Object type tag.
    pub object: String,
    /// Creation timestamp.
    /// NOTE(review): presumably Unix seconds; confirm at the producer.
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionChoice>,
    /// Token accounting for this request.
    pub usage: Usage,
}

/// Token counts for a request. `Default` gives all zeros.
#[derive(Clone, Debug, Default, Serialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub completion_tokens: usize,
    pub total_tokens: usize,
}
/// One event in a streamed chat-completion response.
#[derive(Clone, Debug, Serialize)]
pub struct ChatCompletionChunk {
    pub id: String,
    /// Object type tag.
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionChunkChoice>,
    /// Token usage; left out of the serialized chunk entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}

/// Incremental update for one choice inside a [`ChatCompletionChunk`].
#[derive(Clone, Debug, Serialize)]
pub struct ChatCompletionChunkChoice {
    pub index: usize,
    /// The newly generated fragment.
    pub delta: ChatCompletionDelta,
    /// Serialized as `null` while `None`, because there is no `skip_serializing_if` here.
    pub finish_reason: Option<String>,
}

/// Partial message content carried by a streamed chunk.
///
/// Every field is dropped from the JSON output when it is `None`.
#[derive(Clone, Debug, Serialize)]
pub struct ChatCompletionDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<Role>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
}
/// Body of an embeddings request.
#[derive(Clone, Debug, Deserialize)]
pub struct EmbeddingRequest {
    /// Text (or texts) to embed.
    pub input: EmbeddingInput,
    /// Requested model name; empty when omitted.
    #[serde(default)]
    pub model: String,
    /// Output encoding; `default_encoding_format` supplies the value when omitted.
    #[serde(default = "default_encoding_format")]
    pub encoding_format: String,
}

/// Embedding input: a single string or a list of strings.
///
/// Untagged, so serde tries `Single` first and then `Batch`.
#[derive(Clone, Debug, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
    Single(String),
    Batch(Vec<String>),
}
/// Serde default for `EmbeddingRequest::encoding_format`.
fn default_encoding_format() -> String {
    String::from("float")
}
/// Response body for an embeddings request.
#[derive(Clone, Debug, Serialize)]
pub struct EmbeddingResponse {
    /// Object type tag.
    pub object: String,
    /// One entry per embedded input.
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: EmbeddingUsage,
}

/// A single embedding vector and its position in the input.
#[derive(Clone, Debug, Serialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f32>,
    /// Index of the input this vector belongs to.
    pub index: usize,
}

/// Token accounting for an embeddings request (there are no completion tokens).
#[derive(Clone, Debug, Serialize)]
pub struct EmbeddingUsage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
/// Request to load a model from a path.
#[derive(Clone, Debug, Deserialize)]
pub struct LoadModelRequest {
    /// Path to the model.
    /// NOTE(review): presumably a local filesystem path; confirm with the handler.
    pub model_path: String,
}

/// Result of a model-load request.
#[derive(Clone, Debug, Serialize)]
pub struct LoadModelResponse {
    pub status: String,
    pub model: String,
    /// Context size reported for the loaded model.
    pub context_size: usize,
}

/// Snapshot of request-queue occupancy and limits.
#[derive(Clone, Debug, Serialize)]
pub struct QueueStatusResponse {
    pub active_requests: usize,
    pub queued_requests: usize,
    pub max_queue_depth: usize,
    pub max_concurrent: usize,
}
#[derive(Debug, Clone, Deserialize)]
pub struct CompletionRequest {
#[serde(default)]
pub model: String,
pub prompt: String,
#[serde(default = "default_max_tokens")]
pub max_tokens: usize,
#[serde(default = "default_temperature")]
pub temperature: f32,
#[serde(default = "default_top_p")]
pub top_p: f32,
#[serde(default)]
pub stream: bool,
#[serde(default)]
pub stop: Option<Vec<String>>,
}
/// One generated continuation within a [`CompletionResponse`].
#[derive(Clone, Debug, Serialize)]
pub struct CompletionChoice {
    /// The generated text.
    pub text: String,
    pub index: usize,
    /// Why generation stopped.
    pub finish_reason: String,
}

/// Response body for a plain text-completion request.
#[derive(Clone, Debug, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    /// Object type tag.
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    pub usage: Usage,
}
/// Description of one available model.
#[derive(Clone, Debug, Serialize)]
pub struct ModelInfo {
    pub id: String,
    /// Object type tag.
    pub object: String,
    pub created: u64,
    pub owned_by: String,
}

/// List of available models.
#[derive(Clone, Debug, Serialize)]
pub struct ModelsResponse {
    pub object: String,
    pub data: Vec<ModelInfo>,
}

/// Health-check payload: status plus the current model and its context size.
#[derive(Clone, Debug, Serialize)]
pub struct HealthResponse {
    pub status: String,
    pub model: String,
    pub context_size: usize,
}
/// Error envelope, serialized as `{"error": { ... }}`.
#[derive(Clone, Debug, Serialize)]
pub struct ErrorResponse {
    /// Details of the failure.
    pub error: ErrorDetail,
}

/// Error payload: message, category and optional code.
#[derive(Clone, Debug, Serialize)]
pub struct ErrorDetail {
    /// Human-readable description.
    pub message: String,
    /// Error category. serde strips the `r#` prefix, so the JSON key is `"type"`.
    pub r#type: String,
    /// Optional machine code; serialized as `null` when absent (no `skip_serializing_if`).
    pub code: Option<String>,
}

impl ErrorResponse {
    /// Builds an error response from a message and a category, with `code` left unset.
    pub fn new(message: impl Into<String>, error_type: impl Into<String>) -> Self {
        let detail = ErrorDetail {
            message: message.into(),
            r#type: error_type.into(),
            code: None,
        };
        Self { error: detail }
    }
}
/// Request body for a knowledge-base retrieval; JSON keys are camelCase.
///
/// NOTE(review): the field names appear to mirror Amazon Bedrock's `Retrieve` API. Confirm
/// before relying on exact parity.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RetrieveRequest {
    pub knowledge_base_id: String,
    /// Search text.
    pub query: String,
    #[serde(default)]
    pub retrieval_configuration: Option<RetrievalConfiguration>,
    /// Pagination cursor from a previous response, if any.
    #[serde(default)]
    pub next_token: Option<String>,
}

/// Container for retrieval settings. `Default` leaves the vector-search settings unset.
#[derive(Clone, Debug, Default, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RetrievalConfiguration {
    #[serde(default)]
    pub vector_search_configuration: Option<VectorSearchConfiguration>,
}

/// Vector-search tuning knobs.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct VectorSearchConfiguration {
    /// Result count; `default_num_results` supplies the value when omitted.
    #[serde(default = "default_num_results")]
    pub number_of_results: usize,
    #[serde(default)]
    pub override_search_type: Option<String>,
    #[serde(default)]
    pub filter: Option<RetrievalFilter>,
}
/// Serde default for `VectorSearchConfiguration::number_of_results`.
fn default_num_results() -> usize {
    5_usize
}
/// Recursive metadata filter for retrieval; JSON keys are camelCase.
///
/// `and_all` and `or_all` nest further filters, and each other field is a single-key
/// comparison. How several populated fields combine is decided by the consumer, not here.
/// NOTE(review): the struct has no `deny_unknown_fields`, so serde silently ignores any
/// operator key not listed here. A request using such a key is parsed as if that
/// condition were absent.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RetrievalFilter {
    #[serde(default)]
    pub and_all: Option<Vec<RetrievalFilter>>,
    #[serde(default)]
    pub or_all: Option<Vec<RetrievalFilter>>,
    #[serde(default)]
    pub equals: Option<FilterCondition>,
    #[serde(default)]
    pub not_equals: Option<FilterCondition>,
    #[serde(default)]
    pub greater_than: Option<FilterCondition>,
    #[serde(default)]
    pub less_than: Option<FilterCondition>,
    #[serde(default)]
    pub string_contains: Option<FilterCondition>,
    #[serde(default)]
    pub starts_with: Option<FilterCondition>,
}

/// A metadata key paired with the arbitrary JSON value to compare against.
#[derive(Clone, Debug, Deserialize)]
pub struct FilterCondition {
    pub key: String,
    pub value: serde_json::Value,
}
/// Response body for a retrieval request; JSON keys are camelCase.
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RetrieveResponse {
    pub retrieval_results: Vec<RetrievalResult>,
    /// Pagination cursor; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub next_token: Option<String>,
}

/// One retrieved chunk: its text, source location, score and optional metadata.
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RetrievalResult {
    pub content: RetrievalResultContent,
    pub location: RetrievalResultLocation,
    /// Relevance score.
    /// NOTE(review): presumably higher means more relevant; confirm with the scorer.
    pub score: f32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}

/// Text body of a retrieved chunk.
#[derive(Clone, Debug, Serialize)]
pub struct RetrievalResultContent {
    pub text: String,
}

/// Where a retrieved chunk came from.
///
/// `location_type` is serialized under `"type"`. The optional location objects are
/// omitted when `None`.
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RetrievalResultLocation {
    #[serde(rename = "type")]
    pub location_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub s3_location: Option<S3Location>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub custom_location: Option<CustomLocation>,
}

/// S3-style source location given as a URI.
#[derive(Clone, Debug, Serialize)]
pub struct S3Location {
    pub uri: String,
}

/// Custom source location given as a URI.
#[derive(Clone, Debug, Serialize)]
pub struct CustomLocation {
    pub uri: String,
}
/// Request body for retrieve-then-generate; JSON keys are camelCase.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RetrieveAndGenerateRequest {
    /// The user's question.
    pub input: RetrieveAndGenerateInput,
    pub retrieve_and_generate_configuration: RetrieveAndGenerateConfiguration,
    /// Optional session identifier; `None` when omitted.
    #[serde(default)]
    pub session_id: Option<String>,
}

/// Input text for retrieve-then-generate.
#[derive(Clone, Debug, Deserialize)]
pub struct RetrieveAndGenerateInput {
    pub text: String,
}

/// Top-level configuration; `config_type` is read from the JSON key `"type"`.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RetrieveAndGenerateConfiguration {
    #[serde(rename = "type")]
    pub config_type: String,
    pub knowledge_base_configuration: KnowledgeBaseConfiguration,
}
/// Which knowledge base to query, plus optional model, retrieval and generation settings.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct KnowledgeBaseConfiguration {
    pub knowledge_base_id: String,
    #[serde(default)]
    pub model_arn: Option<String>,
    #[serde(default)]
    pub retrieval_configuration: Option<RetrievalConfiguration>,
    #[serde(default)]
    pub generation_configuration: Option<GenerationConfiguration>,
}

/// Optional prompt template and inference settings for the generation step.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfiguration {
    #[serde(default)]
    pub prompt_template: Option<PromptTemplate>,
    #[serde(default)]
    pub inference_config: Option<InferenceConfig>,
}

/// User-supplied prompt template text.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptTemplate {
    pub text_prompt_template: String,
}

/// Wrapper around text-generation inference settings.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InferenceConfig {
    #[serde(default)]
    pub text_inference_config: Option<TextInferenceConfig>,
}

/// Sampling parameters for generation.
///
/// Each omitted field uses the same defaults as the chat endpoint (`default_temperature`,
/// `default_top_p`, `default_max_tokens`).
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextInferenceConfig {
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    #[serde(default = "default_top_p")]
    pub top_p: f32,
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
}
/// Response body for retrieve-then-generate; JSON keys are camelCase.
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RetrieveAndGenerateResponse {
    /// The generated answer.
    pub output: RetrieveAndGenerateOutput,
    /// Sources backing parts of the answer.
    pub citations: Vec<Citation>,
    /// Session identifier; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
}

/// Generated answer text.
#[derive(Clone, Debug, Serialize)]
pub struct RetrieveAndGenerateOutput {
    pub text: String,
}

/// Links an optional part of the generated answer to the references supporting it.
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Citation {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generated_response_part: Option<GeneratedResponsePart>,
    pub retrieved_references: Vec<RetrievedReference>,
}

/// Wrapper for the cited text portion of the answer.
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GeneratedResponsePart {
    pub text_response_part: TextResponsePart,
}
#[derive(Debug, Clone, Serialize)]
pub struct TextResponsePart {
pub text: String,
pub span: Option<TextSpan>,
}
#[derive(Debug, Clone, Serialize)]
pub struct TextSpan {
pub start: usize,
pub end: usize,
}
/// A source chunk cited by a [`Citation`]: text, location and optional metadata.
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RetrievedReference {
    pub content: RetrievalResultContent,
    pub location: RetrievalResultLocation,
    /// Arbitrary metadata; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
/// Request to add documents to a knowledge base; JSON keys are camelCase.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct IngestRequest {
    pub knowledge_base_id: String,
    pub documents: Vec<IngestDocument>,
}

/// A single document to ingest, with optional free-form metadata.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct IngestDocument {
    pub document_id: String,
    pub content: DocumentContent,
    #[serde(default)]
    pub metadata: Option<serde_json::Value>,
}

/// Raw text of a document to ingest.
#[derive(Clone, Debug, Deserialize)]
pub struct DocumentContent {
    pub text: String,
}

/// Outcome of an ingest request.
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct IngestResponse {
    pub documents_ingested: usize,
    pub chunks_created: usize,
    /// Per-document failures; the key is omitted when the list is empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub failures: Vec<IngestFailure>,
}

/// A document that could not be ingested, and why.
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct IngestFailure {
    pub document_id: String,
    pub error_message: String,
}
/// Paging parameters for listing knowledge bases. `Default` leaves both unset.
#[derive(Clone, Debug, Default, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ListKnowledgeBasesRequest {
    #[serde(default)]
    pub max_results: Option<usize>,
    #[serde(default)]
    pub next_token: Option<String>,
}

/// One page of knowledge-base summaries.
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ListKnowledgeBasesResponse {
    pub knowledge_base_summaries: Vec<KnowledgeBaseSummary>,
    /// Cursor for the next page; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub next_token: Option<String>,
}

/// Brief description of a knowledge base.
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct KnowledgeBaseSummary {
    pub knowledge_base_id: String,
    pub name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    pub status: String,
    /// Last-update timestamp, already formatted as a string.
    pub updated_at: String,
}
/// Response wrapping the full details of one knowledge base.
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GetKnowledgeBaseResponse {
    pub knowledge_base: KnowledgeBaseDetail,
}

/// Full description of a knowledge base, including its storage settings.
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct KnowledgeBaseDetail {
    pub knowledge_base_id: String,
    pub name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    pub status: String,
    pub storage_configuration: StorageConfigurationResponse,
    pub updated_at: String,
}

/// Storage backend kind (serialized under `"type"`) and embedding vector dimension.
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct StorageConfigurationResponse {
    #[serde(rename = "type")]
    pub storage_type: String,
    pub vector_dimension: usize,
}