use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::collections::HashMap;
use uuid::Uuid;
/// Errors surfaced by custom-endpoint configuration, request dispatch,
/// and response parsing.
#[derive(Debug, thiserror::Error)]
pub enum CustomEndpointError {
    /// Transport-level failure from `reqwest` (connection, timeout, body read).
    #[error("HTTP request failed: {0}")]
    HttpError(#[from] reqwest::Error),
    /// The endpoint answered, but the JSON contained no recognisable text field.
    #[error("Invalid response format: {0}")]
    ParseError(String),
    /// The endpoint returned HTTP 401 or 403.
    #[error("Authentication failed")]
    AuthError,
    /// The endpoint returned a non-success status other than 401/403.
    #[error("Endpoint unreachable: {0}")]
    Unreachable(String),
    /// Invalid local configuration (e.g. a custom JSON template that is not
    /// valid JSON after placeholder substitution).
    #[error("Configuration error: {0}")]
    ConfigError(String),
}
/// How requests to a custom endpoint are authenticated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CustomEndpointAuth {
    /// Sent as `Authorization: Bearer <key>`.
    ApiKey(String),
    /// Sent as an arbitrary `<header>: <value>` pair.
    HeaderAuth {
        header: String,
        value: String,
    },
    /// No authentication header is added.
    None,
}
/// Wire format used when building the request body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EndpointRequestFormat {
    /// OpenAI chat-completions shape: `model`/`messages`/`temperature`
    /// (plus `max_tokens` when configured).
    OpenAiCompatible,
    /// Anthropic messages shape: top-level `system` field, mandatory `max_tokens`.
    AnthropicCompatible,
    /// A raw JSON template; the literal `{prompt}` placeholder is substituted
    /// with the user prompt before sending.
    CustomJson(String),
}
/// Configuration for a single user-supplied inference endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomEndpointConfig {
    // Unique id generated at construction time; used as the registry handle.
    pub endpoint_id: Uuid,
    // Full URL the request is POSTed to.
    pub endpoint_url: String,
    pub auth: CustomEndpointAuth,
    // Model identifier forwarded verbatim in the request body.
    pub model_name: String,
    // Optional cap; Anthropic-format requests default to 1024 when `None`.
    pub max_tokens: Option<u32>,
    pub temperature: f32,
    // HTTP client timeout, in seconds.
    pub timeout_secs: u64,
    pub request_format: EndpointRequestFormat,
}
impl CustomEndpointConfig {
    /// Builds a configuration with a freshly generated endpoint id and
    /// default tuning: temperature 0.7, 30-second timeout, no token cap.
    #[must_use]
    pub fn new(
        endpoint_url: impl Into<String>,
        auth: CustomEndpointAuth,
        model_name: impl Into<String>,
        request_format: EndpointRequestFormat,
    ) -> Self {
        Self {
            endpoint_id: Uuid::new_v4(),
            endpoint_url: endpoint_url.into(),
            model_name: model_name.into(),
            auth,
            request_format,
            // Defaults chosen for a general-purpose chat endpoint.
            max_tokens: None,
            temperature: 0.7,
            timeout_secs: 30,
        }
    }
}
/// HTTP client bound to one [`CustomEndpointConfig`].
#[derive(Debug)]
pub struct CustomEndpointClient {
    pub config: CustomEndpointConfig,
    // Built once in `new` with the configured timeout; reused for all requests.
    http_client: reqwest::Client,
}
impl CustomEndpointClient {
pub fn new(config: CustomEndpointConfig) -> Self {
let timeout = std::time::Duration::from_secs(config.timeout_secs);
let http_client = reqwest::Client::builder()
.timeout(timeout)
.build()
.unwrap_or_default();
Self {
config,
http_client,
}
}
pub fn build_request_body(
&self,
prompt: &str,
system: Option<&str>,
) -> Result<Value, CustomEndpointError> {
let body = match &self.config.request_format {
EndpointRequestFormat::OpenAiCompatible => {
let mut messages: Vec<Value> = Vec::new();
if let Some(sys) = system {
messages.push(json!({"role": "system", "content": sys}));
}
messages.push(json!({"role": "user", "content": prompt}));
let mut body = json!({
"model": self.config.model_name,
"messages": messages,
"temperature": self.config.temperature,
});
if let Some(max) = self.config.max_tokens {
body["max_tokens"] = json!(max);
}
body
}
EndpointRequestFormat::AnthropicCompatible => {
let mut messages: Vec<Value> = Vec::new();
messages.push(json!({"role": "user", "content": prompt}));
let mut body = json!({
"model": self.config.model_name,
"messages": messages,
"max_tokens": self.config.max_tokens.unwrap_or(1024),
"temperature": self.config.temperature,
});
if let Some(sys) = system {
body["system"] = json!(sys);
}
body
}
EndpointRequestFormat::CustomJson(template) => {
let rendered = template.replace("{prompt}", prompt);
serde_json::from_str(&rendered).map_err(|e| {
CustomEndpointError::ConfigError(format!(
"Custom JSON template is not valid JSON after substitution: {e}"
))
})?
}
};
Ok(body)
}
fn apply_auth(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
match &self.config.auth {
CustomEndpointAuth::ApiKey(key) => {
builder.header("Authorization", format!("Bearer {key}"))
}
CustomEndpointAuth::HeaderAuth { header, value } => {
builder.header(header.as_str(), value.as_str())
}
CustomEndpointAuth::None => builder,
}
}
fn extract_text(response_json: &Value) -> Result<String, CustomEndpointError> {
if let Some(text) = response_json
.get("choices")
.and_then(|c| c.get(0))
.and_then(|c| c.get("message"))
.and_then(|m| m.get("content"))
.and_then(|v| v.as_str())
{
return Ok(text.to_string());
}
if let Some(text) = response_json
.get("content")
.and_then(|c| c.get(0))
.and_then(|c| c.get("text"))
.and_then(|v| v.as_str())
{
return Ok(text.to_string());
}
if let Some(text) = response_json.get("response").and_then(|v| v.as_str()) {
return Ok(text.to_string());
}
Err(CustomEndpointError::ParseError(
"Response contained no recognisable text field (checked OpenAI, Anthropic, and generic formats)".to_string(),
))
}
pub async fn complete(
&self,
prompt: &str,
system: Option<&str>,
) -> Result<String, CustomEndpointError> {
let body = self.build_request_body(prompt, system)?;
let request = self
.http_client
.post(&self.config.endpoint_url)
.header("Content-Type", "application/json");
let request = self.apply_auth(request);
let response = request.json(&body).send().await?;
if response.status() == reqwest::StatusCode::UNAUTHORIZED
|| response.status() == reqwest::StatusCode::FORBIDDEN
{
return Err(CustomEndpointError::AuthError);
}
if !response.status().is_success() {
return Err(CustomEndpointError::Unreachable(format!(
"Endpoint returned HTTP {}",
response.status()
)));
}
let response_json: Value = response.json().await?;
Self::extract_text(&response_json)
}
pub async fn health_check(&self) -> bool {
match self.complete("ping", None).await {
Ok(_) => true,
Err(CustomEndpointError::HttpError(_) | CustomEndpointError::Unreachable(_)) => false,
Err(CustomEndpointError::AuthError | CustomEndpointError::ParseError(_)) => false,
Err(CustomEndpointError::ConfigError(_)) => false,
}
}
}
/// Maps an issuer token id to the custom endpoint configs registered under it.
#[derive(Debug)]
pub struct CustomEndpointRegistry {
    // Empty vectors are pruned on removal, so a key is present iff it has
    // at least one endpoint.
    endpoints: HashMap<Uuid, Vec<CustomEndpointConfig>>,
}
/// `Default` delegates to [`CustomEndpointRegistry::new`]: an empty registry.
impl Default for CustomEndpointRegistry {
    fn default() -> Self {
        Self::new()
    }
}
impl CustomEndpointRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self {
            endpoints: HashMap::new(),
        }
    }

    /// Registers `config` under `token_id` and returns the endpoint's id.
    pub fn register(&mut self, token_id: Uuid, config: CustomEndpointConfig) -> Uuid {
        let id = config.endpoint_id;
        self.endpoints.entry(token_id).or_default().push(config);
        id
    }

    /// Returns all endpoints registered under `token_id` (empty slice if none).
    #[must_use]
    pub fn get_endpoints(&self, token_id: &Uuid) -> &[CustomEndpointConfig] {
        match self.endpoints.get(token_id) {
            Some(configs) => configs.as_slice(),
            None => &[],
        }
    }

    /// Removes every endpoint under `token_id` whose URL equals
    /// `endpoint_url`; returns whether anything was removed. The token's
    /// entry is dropped entirely once its endpoint list becomes empty.
    pub fn remove(&mut self, token_id: &Uuid, endpoint_url: &str) -> bool {
        let Some(configs) = self.endpoints.get_mut(token_id) else {
            return false;
        };
        let original_len = configs.len();
        configs.retain(|cfg| cfg.endpoint_url != endpoint_url);
        let any_removed = configs.len() != original_len;
        if configs.is_empty() {
            self.endpoints.remove(token_id);
        }
        any_removed
    }

    /// Number of endpoints registered under `token_id` (0 if unknown).
    #[must_use]
    pub fn endpoint_count(&self, token_id: &Uuid) -> usize {
        self.endpoints.get(token_id).map_or(0, Vec::len)
    }
}
/// Preferred tone/verbosity appended to the generated system prompt.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ResponseStyle {
    Concise,
    Detailed,
    Technical,
    Friendly,
}
impl ResponseStyle {
fn instruction(&self) -> &'static str {
match self {
Self::Concise => "Keep your responses brief and to the point.",
Self::Detailed => {
"Provide comprehensive, well-structured answers with relevant context."
}
Self::Technical => {
"Use precise technical terminology and provide accurate, in-depth explanations."
}
Self::Friendly => {
"Use warm, accessible language and explain concepts in a friendly manner."
}
}
}
}
/// Per-issuer prompt customisation: base-prompt override, context documents,
/// and preferred response style.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IssuerPersonalization {
    pub token_id: Uuid,
    // When set, fully replaces the base prompt in `build_system_prompt`.
    pub system_prompt_override: Option<String>,
    // Rendered as a numbered "## Issuer Context" section when non-empty.
    pub context_documents: Vec<String>,
    pub preferred_response_style: ResponseStyle,
}
impl IssuerPersonalization {
    /// Creates a profile for `token_id` with no override, no context
    /// documents, and the `Detailed` response style.
    #[must_use]
    pub fn new(token_id: Uuid) -> Self {
        Self {
            token_id,
            system_prompt_override: None,
            context_documents: Vec::new(),
            preferred_response_style: ResponseStyle::Detailed,
        }
    }

    /// Assembles the effective system prompt: an optional numbered
    /// "## Issuer Context" section, then the base prompt (or its override),
    /// then the style instruction — all newline-joined.
    #[must_use]
    pub fn build_system_prompt(&self, base_prompt: &str) -> String {
        // The override, when present, fully replaces the caller's base prompt.
        let base = match self.system_prompt_override.as_deref() {
            Some(overridden) => overridden,
            None => base_prompt,
        };
        let mut sections: Vec<String> = Vec::new();
        if !self.context_documents.is_empty() {
            // Trailing '\n' plus the final join leaves a blank line after
            // the heading, matching the established output shape.
            sections.push("## Issuer Context\n".to_string());
            sections.extend(
                self.context_documents
                    .iter()
                    .enumerate()
                    .map(|(idx, doc)| format!("{}. {}", idx + 1, doc)),
            );
            sections.push(String::new());
        }
        sections.push(base.to_string());
        sections.push(String::new());
        sections.push(self.preferred_response_style.instruction().to_string());
        sections.join("\n")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline OpenAI-compatible config shared by most tests.
    fn openai_config() -> CustomEndpointConfig {
        CustomEndpointConfig {
            endpoint_id: Uuid::new_v4(),
            endpoint_url: "https://api.example.com/v1/chat/completions".to_string(),
            auth: CustomEndpointAuth::ApiKey("test-key".to_string()),
            model_name: "gpt-4".to_string(),
            max_tokens: Some(512),
            temperature: 0.7,
            timeout_secs: 30,
            request_format: EndpointRequestFormat::OpenAiCompatible,
        }
    }

    /// Same as `openai_config` but switched to the Anthropic wire format.
    fn anthropic_config() -> CustomEndpointConfig {
        let mut cfg = openai_config();
        cfg.model_name = "claude-3-opus-20240229".to_string();
        cfg.request_format = EndpointRequestFormat::AnthropicCompatible;
        cfg
    }

    // `CustomEndpointConfig::new` fills in the documented defaults.
    #[test]
    fn test_config_defaults() {
        let cfg = CustomEndpointConfig::new(
            "https://api.example.com/chat",
            CustomEndpointAuth::None,
            "my-model",
            EndpointRequestFormat::OpenAiCompatible,
        );
        assert_eq!(cfg.temperature, 0.7);
        assert_eq!(cfg.timeout_secs, 30);
        assert!(cfg.max_tokens.is_none());
    }

    // OpenAI format: system prompt becomes the first message, and the
    // configured max_tokens is copied into the body.
    #[test]
    fn test_openai_format_body() {
        let client = CustomEndpointClient::new(openai_config());
        let body = client
            .build_request_body("Hello world", Some("You are a helpful assistant"))
            .expect("build_request_body failed");
        assert_eq!(body["model"], "gpt-4");
        let messages = body["messages"].as_array().expect("messages must be array");
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0]["role"], "system");
        assert_eq!(messages[0]["content"], "You are a helpful assistant");
        assert_eq!(messages[1]["role"], "user");
        assert_eq!(messages[1]["content"], "Hello world");
        assert_eq!(body["max_tokens"], 512);
    }

    // Anthropic format: system prompt is a top-level field, only the user
    // message appears in `messages`, and max_tokens is always present.
    #[test]
    fn test_anthropic_format_body() {
        let client = CustomEndpointClient::new(anthropic_config());
        let body = client
            .build_request_body("What is a blockchain?", Some("Expert assistant"))
            .expect("build_request_body failed");
        assert_eq!(body["model"], "claude-3-opus-20240229");
        assert_eq!(body["system"], "Expert assistant");
        let messages = body["messages"].as_array().expect("messages must be array");
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0]["role"], "user");
        assert_eq!(messages[0]["content"], "What is a blockchain?");
        assert!(body.get("max_tokens").is_some());
    }

    // Custom template: `{prompt}` is substituted and the result parses as JSON.
    #[test]
    fn test_custom_format_body() {
        let template = r#"{"input": "{prompt}", "task": "summarise"}"#.to_string();
        let mut cfg = openai_config();
        cfg.request_format = EndpointRequestFormat::CustomJson(template);
        let client = CustomEndpointClient::new(cfg);
        let body = client
            .build_request_body("Summarise this document", None)
            .expect("build_request_body failed");
        assert_eq!(body["input"], "Summarise this document");
        assert_eq!(body["task"], "summarise");
    }

    // Registering returns the config's endpoint_id and makes it retrievable.
    #[test]
    fn test_registry_register_and_get() {
        let mut registry = CustomEndpointRegistry::new();
        let token_id = Uuid::new_v4();
        let cfg = openai_config();
        let url = cfg.endpoint_url.clone();
        let returned_id = registry.register(token_id, cfg);
        let endpoints = registry.get_endpoints(&token_id);
        assert_eq!(endpoints.len(), 1);
        assert_eq!(endpoints[0].endpoint_url, url);
        assert_eq!(endpoints[0].endpoint_id, returned_id);
    }

    // Removal by URL succeeds once, then reports false on a second attempt.
    #[test]
    fn test_registry_remove() {
        let mut registry = CustomEndpointRegistry::new();
        let token_id = Uuid::new_v4();
        let cfg = openai_config();
        let url = cfg.endpoint_url.clone();
        registry.register(token_id, cfg);
        assert_eq!(registry.endpoint_count(&token_id), 1);
        let removed = registry.remove(&token_id, &url);
        assert!(removed);
        assert_eq!(registry.endpoint_count(&token_id), 0);
        let removed_again = registry.remove(&token_id, &url);
        assert!(!removed_again);
    }

    // Counts are per token id and unknown ids report zero.
    #[test]
    fn test_registry_endpoint_count() {
        let mut registry = CustomEndpointRegistry::new();
        let token_id = Uuid::new_v4();
        assert_eq!(registry.endpoint_count(&token_id), 0);
        let mut cfg1 = openai_config();
        cfg1.endpoint_url = "https://endpoint-one.example.com/chat".to_string();
        registry.register(token_id, cfg1);
        assert_eq!(registry.endpoint_count(&token_id), 1);
        let mut cfg2 = openai_config();
        cfg2.endpoint_url = "https://endpoint-two.example.com/chat".to_string();
        registry.register(token_id, cfg2);
        assert_eq!(registry.endpoint_count(&token_id), 2);
        let other_id = Uuid::new_v4();
        assert_eq!(registry.endpoint_count(&other_id), 0);
    }

    // The assembled prompt contains context docs, the base prompt, and the
    // Technical style instruction.
    #[test]
    fn test_personalization_build_system_prompt() {
        let token_id = Uuid::new_v4();
        let mut personalization = IssuerPersonalization::new(token_id);
        personalization.context_documents = vec![
            "We are a DeFi protocol focused on lending.".to_string(),
            "Our token is KAC.".to_string(),
        ];
        personalization.preferred_response_style = ResponseStyle::Technical;
        let prompt = personalization.build_system_prompt("You are a helpful AI assistant.");
        assert!(prompt.contains("DeFi protocol"));
        assert!(prompt.contains("KAC"));
        assert!(prompt.contains("You are a helpful AI assistant."));
        assert!(prompt.contains("technical terminology"));
    }

    // The override fully replaces the base prompt (base text must not appear).
    #[test]
    fn test_personalization_system_prompt_override() {
        let token_id = Uuid::new_v4();
        let mut personalization = IssuerPersonalization::new(token_id);
        personalization.system_prompt_override =
            Some("You are an expert in DeFi lending protocols.".to_string());
        personalization.preferred_response_style = ResponseStyle::Concise;
        let prompt = personalization.build_system_prompt("You are a helpful AI assistant.");
        assert!(prompt.contains("DeFi lending protocols"));
        assert!(!prompt.contains("helpful AI assistant"));
        assert!(prompt.contains("brief and to the point"));
    }
}