use crate::error::{AgentError, Result as AgentResult};
use crate::session::core::AgentSession;
use std::future::{Future, IntoFuture};
use std::pin::Pin;
use std::sync::atomic::Ordering;
use turboclaude_protocol::{Message, QueryRequest, QueryResponse, RequestId, ToolDefinition};
impl AgentSession {
pub async fn query(&self, request: QueryRequest) -> AgentResult<QueryResponse> {
if request.query.is_empty() {
return Err(AgentError::Config("Query cannot be empty".into()));
}
if request.max_tokens == 0 {
return Err(AgentError::Config("max_tokens must be > 0".into()));
}
self.ensure_connected().await?;
let request_id = RequestId::new();
let count = self.active_queries.fetch_add(1, Ordering::Relaxed);
if count as usize >= self.config.max_concurrent_queries {
self.active_queries.fetch_sub(1, Ordering::Relaxed);
return Err(AgentError::Protocol(format!(
"Too many concurrent queries (max: {})",
self.config.max_concurrent_queries
)));
}
let router_lock = self.router.lock().await;
let router = match router_lock.as_ref() {
Some(r) => r,
None => {
self.active_queries.fetch_sub(1, Ordering::Relaxed);
return Err(AgentError::Transport("Router not initialized".into()));
}
};
let response = router.send_query(request_id, request).await;
self.active_queries.fetch_sub(1, Ordering::Relaxed);
response
}
pub fn query_str(&self, query: impl Into<String>) -> QueryBuilder<'_> {
QueryBuilder::new(self, query.into())
}
pub async fn receive_messages(
&self,
) -> impl futures::Stream<Item = Result<crate::message_parser::ParsedMessage, AgentError>> + '_
{
use crate::message_parser::parse_message;
use futures::stream;
use std::sync::Arc;
let transport = Arc::clone(&self.transport);
stream::unfold(transport, |transport| async move {
match transport.recv_message().await {
Ok(Some(json_value)) => {
match parse_message(json_value) {
Ok(parsed) => Some((Ok(parsed), transport)),
Err(e) => Some((
Err(AgentError::Protocol(format!("Message parse error: {}", e))),
transport,
)),
}
}
Ok(None) => {
None
}
Err(e) => Some((
Err(AgentError::Transport(format!("Transport error: {}", e))),
transport,
)),
}
})
}
}
/// Fluent builder for a single query against an [`AgentSession`].
///
/// Obtained from `AgentSession::query_str`. Nothing is sent until
/// `send()` is called or the builder is `.await`ed directly.
#[must_use = "a QueryBuilder does nothing until `.send()` is called or it is `.await`ed"]
pub struct QueryBuilder<'a> {
    /// Session the finished request is sent through.
    session: &'a AgentSession,
    /// The user query text.
    query: String,
    /// Optional system prompt override.
    system_prompt: Option<String>,
    /// Optional model override; `None` means the session's current model.
    model: Option<String>,
    /// Optional token limit; `None` means the builder's default.
    max_tokens: Option<u32>,
    /// Optional tool definitions; `None` means no tools.
    tools: Option<Vec<ToolDefinition>>,
    /// Optional prior conversation messages; `None` means none.
    messages: Option<Vec<Message>>,
}
impl<'a> QueryBuilder<'a> {
    /// `max_tokens` used when the caller does not set one.
    const DEFAULT_MAX_TOKENS: u32 = 4096;

    /// Creates a builder for `query` with every optional field unset.
    pub(crate) fn new(session: &'a AgentSession, query: String) -> Self {
        Self {
            session,
            query,
            system_prompt: None,
            model: None,
            max_tokens: None,
            tools: None,
            messages: None,
        }
    }

    /// Sets the maximum number of tokens for the response.
    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the system prompt.
    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    /// Overrides the model; otherwise the session's current model is used.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets the tool definitions offered with the query.
    pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
        self.tools = Some(tools);
        self
    }

    /// Sets prior conversation messages to send with the query.
    pub fn messages(mut self, messages: Vec<Message>) -> Self {
        self.messages = Some(messages);
        self
    }

    /// Builds the [`QueryRequest`] and sends it via `AgentSession::query`.
    ///
    /// Defaults applied to unset fields:
    /// - model: the session's current model;
    /// - `max_tokens`: 4096;
    /// - tools and messages: empty lists.
    ///
    /// With the `skills` feature and an installed skill manager, the manager's
    /// context is appended to the system prompt (when non-empty). The manager's
    /// usage counter is also incremented.
    ///
    /// # Errors
    ///
    /// Propagates every error from `AgentSession::query`.
    pub async fn send(self) -> AgentResult<QueryResponse> {
        // Only touch the session state lock when no explicit model was given.
        let model = match self.model {
            Some(model) => model,
            None => self.session.state.lock().await.current_model.clone(),
        };

        // A single read-lock acquisition both builds the context and records
        // usage, so both operations see the same skill manager.
        #[cfg(feature = "skills")]
        let system_prompt = {
            let manager = self.session.skill_manager.read().await;
            match manager.as_ref() {
                Some(m) => {
                    let skill_context = m.build_context().await;
                    m.increment_usage().await;
                    if skill_context.is_empty() {
                        self.system_prompt
                    } else {
                        let current_prompt = self.system_prompt.unwrap_or_default();
                        Some(format!("{}{}", current_prompt, skill_context))
                    }
                }
                None => self.system_prompt,
            }
        };
        #[cfg(not(feature = "skills"))]
        let system_prompt = self.system_prompt;

        let request = QueryRequest {
            query: self.query,
            system_prompt,
            model,
            max_tokens: self.max_tokens.unwrap_or(Self::DEFAULT_MAX_TOKENS),
            tools: self.tools.unwrap_or_default(),
            messages: self.messages.unwrap_or_default(),
        };
        self.session.query(request).await
    }
}
/// Lets a [`QueryBuilder`] be `.await`ed directly.
///
/// Awaiting the builder is exactly equivalent to awaiting `QueryBuilder::send`.
impl<'a> IntoFuture for QueryBuilder<'a> {
    type Output = AgentResult<QueryResponse>;
    // Boxed and type-erased because the future produced by the `async fn send`
    // has no nameable type.
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;

    fn into_future(self) -> Self::IntoFuture {
        let pending = self.send();
        Box::pin(pending)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicU32;
    use std::sync::Arc;

    /// Exercises the reserve/release counter pattern behind `active_queries`.
    ///
    /// Each `fetch_add` observes the pre-increment value. Every reservation is
    /// matched by a `fetch_sub`, which returns the shared counter to zero.
    #[test]
    fn test_concurrent_query_tracking() {
        let counter = Arc::new(AtomicU32::new(0));
        let handles: Vec<Arc<AtomicU32>> = (0..3).map(|_| Arc::clone(&counter)).collect();

        let observed: Vec<u32> = handles
            .iter()
            .map(|handle| handle.fetch_add(1, Ordering::Relaxed))
            .collect();
        assert_eq!(observed, vec![0, 1, 2]);

        handles.iter().for_each(|handle| {
            handle.fetch_sub(1, Ordering::Relaxed);
        });
        assert_eq!(counter.load(Ordering::Relaxed), 0);
    }
}