use crate::{
chat::{
ChatCompletionRequest, ChatCompletionRequestMessage, ChatCompletionRequestSampling,
ChatResponseFormat, StreamOptions, Tool, ToolChoice,
},
embeddings::EmbeddingRequest,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagEmbeddingRequest {
#[serde(rename = "embeddings")]
pub embedding_request: EmbeddingRequest,
#[serde(rename = "url")]
pub qdrant_url: String,
#[serde(rename = "collection_name")]
pub qdrant_collection_name: String,
}
impl RagEmbeddingRequest {
pub fn new(
input: &[String],
qdrant_url: impl AsRef<str>,
qdrant_collection_name: impl AsRef<str>,
) -> Self {
RagEmbeddingRequest {
embedding_request: EmbeddingRequest {
model: "dummy-embedding-model".to_string(),
input: input.into(),
encoding_format: None,
user: None,
},
qdrant_url: qdrant_url.as_ref().to_string(),
qdrant_collection_name: qdrant_collection_name.as_ref().to_string(),
}
}
pub fn from_embedding_request(
embedding_request: EmbeddingRequest,
qdrant_url: impl AsRef<str>,
qdrant_collection_name: impl AsRef<str>,
) -> Self {
RagEmbeddingRequest {
embedding_request,
qdrant_url: qdrant_url.as_ref().to_string(),
qdrant_collection_name: qdrant_collection_name.as_ref().to_string(),
}
}
}
#[test]
fn test_rag_serialize_embedding_request() {
let embedding_request = EmbeddingRequest {
model: "model".to_string(),
input: "Hello, world!".into(),
encoding_format: None,
user: None,
};
let qdrant_url = "http://localhost:6333".to_string();
let qdrant_collection_name = "qdrant_collection_name".to_string();
let rag_embedding_request = RagEmbeddingRequest {
embedding_request,
qdrant_url,
qdrant_collection_name,
};
let json = serde_json::to_string(&rag_embedding_request).unwrap();
assert_eq!(
json,
r#"{"embeddings":{"model":"model","input":"Hello, world!"},"url":"http://localhost:6333","collection_name":"qdrant_collection_name"}"#
);
}
#[test]
fn test_rag_deserialize_embedding_request() {
let json = r#"{"embeddings":{"model":"model","input":["Hello, world!"]},"url":"http://localhost:6333","collection_name":"qdrant_collection_name"}"#;
let rag_embedding_request: RagEmbeddingRequest = serde_json::from_str(json).unwrap();
assert_eq!(rag_embedding_request.qdrant_url, "http://localhost:6333");
assert_eq!(
rag_embedding_request.qdrant_collection_name,
"qdrant_collection_name"
);
assert_eq!(rag_embedding_request.embedding_request.model, "model");
assert_eq!(
rag_embedding_request.embedding_request.input,
vec!["Hello, world!"].into()
);
}
#[derive(Debug, Deserialize, Serialize)]
pub struct RagChatCompletionsRequest {
#[serde(skip_serializing_if = "Option::is_none")]
pub chat_model: Option<String>,
pub messages: Vec<ChatCompletionRequestMessage>,
pub embedding_model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub encoding_format: Option<String>,
pub qdrant_url: String,
pub qdrant_collection_name: String,
pub limit: u64,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n_choice: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_options: Option<StreamOptions>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<HashMap<String, f64>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ChatResponseFormat>,
pub tools: Option<Vec<Tool>>,
pub tool_choice: Option<ToolChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub context_window: Option<u64>,
}
impl RagChatCompletionsRequest {
pub fn as_chat_completions_request(&self) -> ChatCompletionRequest {
ChatCompletionRequest {
model: self.chat_model.clone(),
messages: self.messages.clone(),
temperature: self.temperature,
top_p: self.top_p,
n_choice: self.n_choice,
stream: self.stream,
stream_options: self.stream_options.clone(),
stop: self.stop.clone(),
max_tokens: self.max_tokens,
presence_penalty: self.presence_penalty,
frequency_penalty: self.frequency_penalty,
logit_bias: self.logit_bias.clone(),
user: self.user.clone(),
functions: None,
function_call: None,
response_format: self.response_format.clone(),
tool_choice: self.tool_choice.clone(),
tools: self.tools.clone(),
context_window: self.context_window,
}
}
pub fn from_chat_completions_request(
chat_completions_request: ChatCompletionRequest,
qdrant_url: impl Into<String>,
qdrant_collection_name: impl Into<String>,
limit: u64,
) -> Self {
RagChatCompletionsRequest {
chat_model: chat_completions_request.model,
messages: chat_completions_request.messages,
embedding_model: "dummy-embedding-model".to_string(),
encoding_format: None,
qdrant_url: qdrant_url.into(),
qdrant_collection_name: qdrant_collection_name.into(),
limit,
temperature: chat_completions_request.temperature,
top_p: chat_completions_request.top_p,
n_choice: chat_completions_request.n_choice,
stream: chat_completions_request.stream,
stream_options: chat_completions_request.stream_options,
stop: chat_completions_request.stop,
max_tokens: chat_completions_request.max_tokens,
presence_penalty: chat_completions_request.presence_penalty,
frequency_penalty: chat_completions_request.frequency_penalty,
logit_bias: chat_completions_request.logit_bias,
user: chat_completions_request.user,
response_format: chat_completions_request.response_format,
tool_choice: chat_completions_request.tool_choice,
tools: chat_completions_request.tools,
context_window: chat_completions_request.context_window,
}
}
}
pub struct RagChatCompletionRequestBuilder {
req: RagChatCompletionsRequest,
}
impl RagChatCompletionRequestBuilder {
pub fn new(
messages: Vec<ChatCompletionRequestMessage>,
qdrant_url: impl Into<String>,
qdrant_collection_name: impl Into<String>,
limit: u64,
) -> Self {
Self {
req: RagChatCompletionsRequest {
chat_model: Some("dummy-chat-model".to_string()),
messages,
embedding_model: "dummy-embedding-model".to_string(),
encoding_format: Some("float".to_string()),
qdrant_url: qdrant_url.into(),
qdrant_collection_name: qdrant_collection_name.into(),
limit,
temperature: Some(1.0),
top_p: Some(1.0),
n_choice: Some(1),
stream: Some(false),
stream_options: None,
stop: None,
max_tokens: Some(1024),
presence_penalty: Some(0.0),
frequency_penalty: Some(0.0),
logit_bias: None,
user: None,
response_format: None,
tool_choice: None,
tools: None,
context_window: Some(1),
},
}
}
pub fn with_sampling(mut self, sampling: ChatCompletionRequestSampling) -> Self {
let (temperature, top_p) = match sampling {
ChatCompletionRequestSampling::Temperature(t) => (t, 1.0),
ChatCompletionRequestSampling::TopP(p) => (1.0, p),
};
self.req.temperature = Some(temperature);
self.req.top_p = Some(top_p);
self
}
pub fn with_n_choices(mut self, n: u64) -> Self {
let n_choice = if n < 1 { 1 } else { n };
self.req.n_choice = Some(n_choice);
self
}
pub fn with_stream(mut self, flag: bool) -> Self {
self.req.stream = Some(flag);
self
}
pub fn with_stop(mut self, stop: Vec<String>) -> Self {
self.req.stop = Some(stop);
self
}
pub fn with_max_tokens(mut self, max_tokens: u64) -> Self {
let max_tokens = if max_tokens < 1 { 16 } else { max_tokens };
self.req.max_tokens = Some(max_tokens);
self
}
pub fn with_presence_penalty(mut self, penalty: f64) -> Self {
self.req.presence_penalty = Some(penalty);
self
}
pub fn with_frequency_penalty(mut self, penalty: f64) -> Self {
self.req.frequency_penalty = Some(penalty);
self
}
pub fn with_logits_bias(mut self, map: HashMap<String, f64>) -> Self {
self.req.logit_bias = Some(map);
self
}
pub fn with_user(mut self, user: impl Into<String>) -> Self {
self.req.user = Some(user.into());
self
}
pub fn with_context_window(mut self, context_window: u64) -> Self {
self.req.context_window = Some(context_window);
self
}
pub fn build(self) -> RagChatCompletionsRequest {
self.req
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunksRequest {
pub id: String,
pub filename: String,
pub chunk_capacity: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunksResponse {
pub id: String,
pub filename: String,
pub chunks: Vec<String>,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RetrieveObject {
#[serde(skip_serializing_if = "Option::is_none")]
pub points: Option<Vec<RagScoredPoint>>,
pub limit: usize,
pub score_threshold: f32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagScoredPoint {
pub source: String,
pub score: f32,
}
#[test]
fn test_rag_serialize_retrieve_object() {
{
let ro = RetrieveObject {
points: Some(vec![RagScoredPoint {
source: "source".to_string(),
score: 0.5,
}]),
limit: 1,
score_threshold: 0.5,
};
let json = serde_json::to_string(&ro).unwrap();
assert_eq!(
json,
r#"{"points":[{"source":"source","score":0.5}],"limit":1,"score_threshold":0.5}"#
);
}
{
let ro = RetrieveObject {
points: None,
limit: 1,
score_threshold: 0.5,
};
let json = serde_json::to_string(&ro).unwrap();
assert_eq!(json, r#"{"limit":1,"score_threshold":0.5}"#);
}
}
#[test]
fn test_rag_deserialize_retrieve_object() {
{
let json =
r#"{"points":[{"source":"source","score":0.5}],"limit":1,"score_threshold":0.5}"#;
let ro: RetrieveObject = serde_json::from_str(json).unwrap();
assert_eq!(ro.limit, 1);
assert_eq!(ro.score_threshold, 0.5);
assert!(ro.points.is_some());
let points = ro.points.unwrap();
assert_eq!(points.len(), 1);
assert_eq!(points[0].source, "source");
assert_eq!(points[0].score, 0.5);
}
{
let json = r#"{"limit":1,"score_threshold":0.5}"#;
let ro: RetrieveObject = serde_json::from_str(json).unwrap();
assert_eq!(ro.limit, 1);
assert_eq!(ro.score_threshold, 0.5);
assert!(ro.points.is_none());
}
}