use async_trait::async_trait;
use reqwest::Client;
use rust_decimal::Decimal;
use rust_decimal_macros::dec;
use secrecy::ExposeSecret;
use serde::{Deserialize, Serialize};
use crate::config::NearAiConfig;
use crate::error::LlmError;
use crate::llm::provider::{
ChatMessage, CompletionRequest, CompletionResponse, FinishReason, LlmProvider, ModelMetadata,
Role, ToolCall, ToolCompletionRequest, ToolCompletionResponse,
};
use crate::llm::retry::{is_retryable_status, retry_backoff_delay};
pub struct NearAiChatProvider {
client: Client,
config: NearAiConfig,
active_model: std::sync::RwLock<String>,
}
impl NearAiChatProvider {
pub fn new(config: NearAiConfig) -> Result<Self, LlmError> {
if config.api_key.is_none() {
return Err(LlmError::AuthFailed {
provider: "nearai_chat".to_string(),
});
}
let client = Client::builder()
.timeout(std::time::Duration::from_secs(120))
.build()
.unwrap_or_else(|_| Client::new());
let active_model = std::sync::RwLock::new(config.model.clone());
Ok(Self {
client,
config,
active_model,
})
}
fn api_url(&self, path: &str) -> String {
format!(
"{}/v1/{}",
self.config.base_url,
path.trim_start_matches('/')
)
}
fn api_key(&self) -> String {
self.config
.api_key
.as_ref()
.map(|k| k.expose_secret().to_string())
.unwrap_or_default()
}
async fn send_request<T: Serialize, R: for<'de> Deserialize<'de>>(
&self,
body: &T,
) -> Result<R, LlmError> {
let url = self.api_url("chat/completions");
let max_retries = self.config.max_retries;
for attempt in 0..=max_retries {
tracing::debug!(
"Sending request to NEAR AI Chat: {} (attempt {})",
url,
attempt + 1,
);
if tracing::enabled!(tracing::Level::DEBUG)
&& let Ok(json) = serde_json::to_string(body)
{
tracing::debug!("NEAR AI Chat request body: {}", json);
}
let response = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key()))
.header("Content-Type", "application/json")
.json(body)
.send()
.await;
let response = match response {
Ok(r) => r,
Err(e) => {
tracing::error!("NEAR AI Chat request failed: {}", e);
if attempt < max_retries {
let delay = retry_backoff_delay(attempt);
tracing::warn!(
"NEAR AI Chat request error (attempt {}/{}), retrying in {:?}: {}",
attempt + 1,
max_retries + 1,
delay,
e,
);
tokio::time::sleep(delay).await;
continue;
}
return Err(LlmError::RequestFailed {
provider: "nearai_chat".to_string(),
reason: e.to_string(),
});
}
};
let status = response.status();
let response_text = response.text().await.unwrap_or_default();
tracing::debug!("NEAR AI Chat response status: {}", status);
tracing::debug!("NEAR AI Chat response body: {}", response_text);
if !status.is_success() {
let status_code = status.as_u16();
if status_code == 401 {
return Err(LlmError::AuthFailed {
provider: "nearai_chat".to_string(),
});
}
if is_retryable_status(status_code) && attempt < max_retries {
let delay = retry_backoff_delay(attempt);
tracing::warn!(
"NEAR AI Chat returned HTTP {} (attempt {}/{}), retrying in {:?}",
status_code,
attempt + 1,
max_retries + 1,
delay,
);
tokio::time::sleep(delay).await;
continue;
}
if status_code == 429 {
return Err(LlmError::RateLimited {
provider: "nearai_chat".to_string(),
retry_after: None,
});
}
return Err(LlmError::RequestFailed {
provider: "nearai_chat".to_string(),
reason: format!("HTTP {}: {}", status, response_text),
});
}
return serde_json::from_str(&response_text).map_err(|e| LlmError::InvalidResponse {
provider: "nearai_chat".to_string(),
reason: format!("JSON parse error: {}. Raw: {}", e, response_text),
});
}
Err(LlmError::RequestFailed {
provider: "nearai_chat".to_string(),
reason: "retry loop exited unexpectedly".to_string(),
})
}
async fn fetch_models(&self) -> Result<Vec<ApiModelEntry>, LlmError> {
let url = self.api_url("models");
let response = self
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.api_key()))
.send()
.await
.map_err(|e| LlmError::RequestFailed {
provider: "nearai_chat".to_string(),
reason: format!("Failed to fetch models: {}", e),
})?;
let status = response.status();
let response_text = response.text().await.unwrap_or_default();
if !status.is_success() {
return Err(LlmError::RequestFailed {
provider: "nearai_chat".to_string(),
reason: format!("HTTP {}: {}", status, response_text),
});
}
#[derive(Deserialize)]
struct ModelsResponse {
data: Vec<ApiModelEntry>,
}
let resp: ModelsResponse =
serde_json::from_str(&response_text).map_err(|e| LlmError::InvalidResponse {
provider: "nearai_chat".to_string(),
reason: format!("JSON parse error: {}", e),
})?;
Ok(resp.data)
}
}
#[derive(Debug, Deserialize)]
struct ApiModelEntry {
id: String,
#[serde(default)]
context_length: Option<u32>,
}
#[async_trait]
impl LlmProvider for NearAiChatProvider {
async fn complete(&self, req: CompletionRequest) -> Result<CompletionResponse, LlmError> {
let messages: Vec<ChatCompletionMessage> =
req.messages.into_iter().map(|m| m.into()).collect();
let request = ChatCompletionRequest {
model: self.active_model_name(),
messages,
temperature: req.temperature,
max_tokens: req.max_tokens,
tools: None,
tool_choice: None,
};
let response: ChatCompletionResponse = self.send_request(&request).await?;
let choice =
response
.choices
.into_iter()
.next()
.ok_or_else(|| LlmError::InvalidResponse {
provider: "nearai_chat".to_string(),
reason: "No choices in response".to_string(),
})?;
let content = choice.message.content.unwrap_or_default();
let finish_reason = match choice.finish_reason.as_deref() {
Some("stop") => FinishReason::Stop,
Some("length") => FinishReason::Length,
Some("tool_calls") => FinishReason::ToolUse,
Some("content_filter") => FinishReason::ContentFilter,
_ => FinishReason::Unknown,
};
Ok(CompletionResponse {
content,
finish_reason,
input_tokens: response.usage.prompt_tokens,
output_tokens: response.usage.completion_tokens,
response_id: None,
})
}
async fn complete_with_tools(
&self,
req: ToolCompletionRequest,
) -> Result<ToolCompletionResponse, LlmError> {
let messages: Vec<ChatCompletionMessage> =
req.messages.into_iter().map(|m| m.into()).collect();
let messages = flatten_tool_messages(messages);
let tools: Vec<ChatCompletionTool> = req
.tools
.into_iter()
.map(|t| ChatCompletionTool {
tool_type: "function".to_string(),
function: ChatCompletionFunction {
name: t.name,
description: Some(t.description),
parameters: Some(t.parameters),
},
})
.collect();
let request = ChatCompletionRequest {
model: self.active_model_name(),
messages,
temperature: req.temperature,
max_tokens: req.max_tokens,
tools: if tools.is_empty() { None } else { Some(tools) },
tool_choice: req.tool_choice,
};
let response: ChatCompletionResponse = self.send_request(&request).await?;
let choice =
response
.choices
.into_iter()
.next()
.ok_or_else(|| LlmError::InvalidResponse {
provider: "nearai_chat".to_string(),
reason: "No choices in response".to_string(),
})?;
let content = choice.message.content;
let tool_calls: Vec<ToolCall> = choice
.message
.tool_calls
.unwrap_or_default()
.into_iter()
.map(|tc| {
let arguments = serde_json::from_str(&tc.function.arguments)
.unwrap_or(serde_json::Value::Object(Default::default()));
ToolCall {
id: tc.id,
name: tc.function.name,
arguments,
}
})
.collect();
let finish_reason = match choice.finish_reason.as_deref() {
Some("stop") => FinishReason::Stop,
Some("length") => FinishReason::Length,
Some("tool_calls") => FinishReason::ToolUse,
Some("content_filter") => FinishReason::ContentFilter,
_ => {
if !tool_calls.is_empty() {
FinishReason::ToolUse
} else {
FinishReason::Unknown
}
}
};
Ok(ToolCompletionResponse {
content,
tool_calls,
finish_reason,
input_tokens: response.usage.prompt_tokens,
output_tokens: response.usage.completion_tokens,
response_id: None,
})
}
fn model_name(&self) -> &str {
&self.config.model
}
fn cost_per_token(&self) -> (Decimal, Decimal) {
(dec!(0.000003), dec!(0.000015))
}
async fn list_models(&self) -> Result<Vec<String>, LlmError> {
let models = self.fetch_models().await?;
Ok(models.into_iter().map(|m| m.id).collect())
}
async fn model_metadata(&self) -> Result<ModelMetadata, LlmError> {
let active = self.active_model_name();
let models = self.fetch_models().await?;
let current = models.iter().find(|m| m.id == active);
Ok(ModelMetadata {
id: active,
context_length: current.and_then(|m| m.context_length),
})
}
fn active_model_name(&self) -> String {
self.active_model
.read()
.expect("active_model lock poisoned")
.clone()
}
fn set_model(&self, model: &str) -> Result<(), crate::error::LlmError> {
let mut guard = self
.active_model
.write()
.expect("active_model lock poisoned");
*guard = model.to_string();
Ok(())
}
}
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
model: String,
messages: Vec<ChatCompletionMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<ChatCompletionTool>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_choice: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
struct ChatCompletionMessage {
role: String,
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_call_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<ChatCompletionToolCall>>,
}
fn flatten_tool_messages(messages: Vec<ChatCompletionMessage>) -> Vec<ChatCompletionMessage> {
let has_tool_msgs = messages.iter().any(|m| m.role == "tool");
if !has_tool_msgs {
return messages;
}
tracing::debug!("Flattening tool messages for NEAR AI compatibility");
messages
.into_iter()
.map(|msg| {
if let (true, Some(calls)) = (msg.role == "assistant", &msg.tool_calls) {
let mut parts: Vec<String> = Vec::new();
if let Some(ref text) = msg.content
&& !text.is_empty()
{
parts.push(text.clone());
}
for tc in calls {
parts.push(format!(
"[Called tool `{}` with arguments: {}]",
tc.function.name, tc.function.arguments
));
}
ChatCompletionMessage {
role: "assistant".to_string(),
content: Some(parts.join("\n")),
tool_call_id: None,
name: None,
tool_calls: None,
}
} else if msg.role == "tool" {
let tool_name = msg.name.as_deref().unwrap_or("unknown");
let result = msg.content.as_deref().unwrap_or("");
ChatCompletionMessage {
role: "user".to_string(),
content: Some(format!("[Tool `{}` returned: {}]", tool_name, result)),
tool_call_id: None,
name: None,
tool_calls: None,
}
} else {
msg
}
})
.collect()
}
impl From<ChatMessage> for ChatCompletionMessage {
fn from(msg: ChatMessage) -> Self {
let role = match msg.role {
Role::System => "system",
Role::User => "user",
Role::Assistant => "assistant",
Role::Tool => "tool",
};
let tool_calls = msg.tool_calls.map(|calls| {
calls
.into_iter()
.map(|tc| ChatCompletionToolCall {
id: tc.id,
call_type: "function".to_string(),
function: ChatCompletionToolCallFunction {
name: tc.name,
arguments: tc.arguments.to_string(),
},
})
.collect()
});
let content = if role == "assistant" && tool_calls.is_some() && msg.content.is_empty() {
None
} else {
Some(msg.content)
};
Self {
role: role.to_string(),
content,
tool_call_id: msg.tool_call_id,
name: msg.name,
tool_calls,
}
}
}
#[derive(Debug, Serialize)]
struct ChatCompletionTool {
#[serde(rename = "type")]
tool_type: String,
function: ChatCompletionFunction,
}
#[derive(Debug, Serialize)]
struct ChatCompletionFunction {
name: String,
#[serde(skip_serializing_if = "Option::is_none")]
description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
parameters: Option<serde_json::Value>,
}
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
#[allow(dead_code)]
id: String,
choices: Vec<ChatCompletionChoice>,
usage: ChatCompletionUsage,
}
#[derive(Debug, Deserialize)]
struct ChatCompletionChoice {
message: ChatCompletionResponseMessage,
finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
struct ChatCompletionResponseMessage {
#[allow(dead_code)]
role: String,
content: Option<String>,
tool_calls: Option<Vec<ChatCompletionToolCall>>,
}
#[derive(Debug, Serialize, Deserialize)]
struct ChatCompletionToolCall {
id: String,
#[serde(rename = "type")]
#[allow(dead_code)]
call_type: String,
function: ChatCompletionToolCallFunction,
}
#[derive(Debug, Serialize, Deserialize)]
struct ChatCompletionToolCallFunction {
name: String,
arguments: String,
}
#[derive(Debug, Deserialize)]
struct ChatCompletionUsage {
prompt_tokens: u32,
completion_tokens: u32,
#[allow(dead_code)]
total_tokens: u32,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_message_conversion() {
let msg = ChatMessage::user("Hello");
let chat_msg: ChatCompletionMessage = msg.into();
assert_eq!(chat_msg.role, "user");
assert_eq!(chat_msg.content, Some("Hello".to_string()));
}
#[test]
fn test_tool_message_conversion() {
let msg = ChatMessage::tool_result("call_123", "my_tool", "result");
let chat_msg: ChatCompletionMessage = msg.into();
assert_eq!(chat_msg.role, "tool");
assert_eq!(chat_msg.tool_call_id, Some("call_123".to_string()));
assert_eq!(chat_msg.name, Some("my_tool".to_string()));
}
#[test]
fn test_assistant_with_tool_calls_conversion() {
use crate::llm::ToolCall;
let tool_calls = vec![
ToolCall {
id: "call_1".to_string(),
name: "list_issues".to_string(),
arguments: serde_json::json!({"owner": "foo", "repo": "bar"}),
},
ToolCall {
id: "call_2".to_string(),
name: "search".to_string(),
arguments: serde_json::json!({"query": "test"}),
},
];
let msg = ChatMessage::assistant_with_tool_calls(None, tool_calls);
let chat_msg: ChatCompletionMessage = msg.into();
assert_eq!(chat_msg.role, "assistant");
let tc = chat_msg.tool_calls.expect("tool_calls present");
assert_eq!(tc.len(), 2);
assert_eq!(tc[0].id, "call_1");
assert_eq!(tc[0].function.name, "list_issues");
assert_eq!(tc[0].call_type, "function");
assert_eq!(tc[1].id, "call_2");
assert_eq!(tc[1].function.name, "search");
}
#[test]
fn test_assistant_without_tool_calls_has_none() {
let msg = ChatMessage::assistant("Hello");
let chat_msg: ChatCompletionMessage = msg.into();
assert!(chat_msg.tool_calls.is_none());
}
#[test]
fn test_tool_call_arguments_serialized_to_string() {
use crate::llm::ToolCall;
let tc = ToolCall {
id: "call_1".to_string(),
name: "test".to_string(),
arguments: serde_json::json!({"key": "value"}),
};
let msg = ChatMessage::assistant_with_tool_calls(None, vec![tc]);
let chat_msg: ChatCompletionMessage = msg.into();
let calls = chat_msg.tool_calls.unwrap();
let parsed: serde_json::Value =
serde_json::from_str(&calls[0].function.arguments).expect("valid JSON string");
assert_eq!(parsed["key"], "value");
}
#[test]
fn test_flatten_no_tool_messages_passthrough() {
let messages = vec![
ChatCompletionMessage {
role: "system".to_string(),
content: Some("You are helpful.".to_string()),
tool_call_id: None,
name: None,
tool_calls: None,
},
ChatCompletionMessage {
role: "user".to_string(),
content: Some("Hello".to_string()),
tool_call_id: None,
name: None,
tool_calls: None,
},
];
let result = flatten_tool_messages(messages);
assert_eq!(result.len(), 2);
assert_eq!(result[0].role, "system");
assert_eq!(result[1].role, "user");
}
#[test]
fn test_flatten_tool_call_and_result() {
let messages = vec![
ChatCompletionMessage {
role: "user".to_string(),
content: Some("test".to_string()),
tool_call_id: None,
name: None,
tool_calls: None,
},
ChatCompletionMessage {
role: "assistant".to_string(),
content: None,
tool_call_id: None,
name: None,
tool_calls: Some(vec![ChatCompletionToolCall {
id: "call_1".to_string(),
call_type: "function".to_string(),
function: ChatCompletionToolCallFunction {
name: "echo".to_string(),
arguments: r#"{"message":"hi"}"#.to_string(),
},
}]),
},
ChatCompletionMessage {
role: "tool".to_string(),
content: Some("hi".to_string()),
tool_call_id: Some("call_1".to_string()),
name: Some("echo".to_string()),
tool_calls: None,
},
];
let result = flatten_tool_messages(messages);
assert_eq!(result.len(), 3);
assert_eq!(result[1].role, "assistant");
assert!(result[1].tool_calls.is_none());
assert!(
result[1]
.content
.as_ref()
.unwrap()
.contains("[Called tool `echo`")
);
assert_eq!(result[2].role, "user");
assert!(result[2].tool_call_id.is_none());
assert!(
result[2]
.content
.as_ref()
.unwrap()
.contains("[Tool `echo` returned: hi]")
);
}
#[test]
fn test_flatten_preserves_assistant_text_with_tool_calls() {
let messages = vec![
ChatCompletionMessage {
role: "assistant".to_string(),
content: Some("Let me check that.".to_string()),
tool_call_id: None,
name: None,
tool_calls: Some(vec![ChatCompletionToolCall {
id: "call_1".to_string(),
call_type: "function".to_string(),
function: ChatCompletionToolCallFunction {
name: "search".to_string(),
arguments: r#"{"q":"test"}"#.to_string(),
},
}]),
},
ChatCompletionMessage {
role: "tool".to_string(),
content: Some("found it".to_string()),
tool_call_id: Some("call_1".to_string()),
name: Some("search".to_string()),
tool_calls: None,
},
];
let result = flatten_tool_messages(messages);
let text = result[0].content.as_ref().unwrap();
assert!(text.starts_with("Let me check that."));
assert!(text.contains("[Called tool `search`"));
}
}