use crate::client::rest_client::RestClient;
use crate::client::socket_client::SocketClient;
use crate::types::{RunAgentError, RunAgentResult};
use crate::utils::serializer::CoreSerializer;
use futures::Stream;
use serde_json::Value;
use std::collections::HashMap;
use std::pin::Pin;
#[cfg(feature = "db")]
use crate::db::DatabaseService;
/// Client bound to one agent entrypoint, talking either to a local agent
/// server or to the remote RunAgent API.
pub struct RunAgentClient {
    /// Identifier of the target agent.
    agent_id: String,
    /// Entrypoint invoked by the `run*` / `run_stream*` methods.
    entrypoint_tag: String,
    /// `true` when the client talks to a local agent server.
    local: bool,
    /// HTTP client used for architecture lookups, health checks and `run*` calls.
    rest_client: RestClient,
    /// WebSocket client used for `run_stream*` calls.
    socket_client: SocketClient,
    /// Serializer applied to `run*` results before they are returned.
    serializer: CoreSerializer,
    /// Agent architecture JSON fetched when the client is constructed.
    agent_architecture: Option<Value>,
    /// Caller-supplied extra parameters, exposed through `extra_params()`.
    extra_params: Option<HashMap<String, Value>>,
    /// User identifier forwarded with every run request.
    user_id: Option<String>,
    /// Whether run requests ask for persistent memory.
    persistent_memory: bool,
    /// Agent registry handle. `RunAgentClient::new` always leaves it `None` and
    /// nothing reads it, which is why `dead_code` is allowed here.
    #[cfg(feature = "db")]
    #[allow(dead_code)]
    db_service: Option<DatabaseService>,
}
/// Configuration consumed by `RunAgentClient::new`.
///
/// `Default` yields empty `agent_id` / `entrypoint_tag` and leaves every
/// optional setting unset (`None`). Unset options fall back to the defaults
/// described on each field.
#[derive(Debug, Clone, Default)]
pub struct RunAgentClientConfig {
    /// Identifier of the agent to call.
    pub agent_id: String,
    /// Entrypoint tag to invoke. Tags ending in `_stream` must be called through
    /// the streaming methods.
    pub entrypoint_tag: String,
    /// Talk to a local agent server instead of the remote API (default: `false`).
    pub local: Option<bool>,
    /// Host of a local agent; only consulted when `local` is `true`.
    pub host: Option<String>,
    /// Port of a local agent; only consulted when `local` is `true`.
    pub port: Option<u16>,
    /// API key for the remote API. Falls back to the `ENV_RUNAGENT_API_KEY`
    /// environment variable.
    pub api_key: Option<String>,
    /// Base URL of the remote API. Falls back to the `ENV_RUNAGENT_BASE_URL`
    /// environment variable, then `DEFAULT_BASE_URL`.
    pub base_url: Option<String>,
    /// Extra parameters stored on the client. They are not forwarded by the
    /// `run*` methods.
    pub extra_params: Option<HashMap<String, Value>>,
    /// Look up a local agent's address in the registry database when host/port
    /// are not both given (default: the value of `local`).
    pub enable_registry: Option<bool>,
    /// User identifier forwarded with every run request.
    pub user_id: Option<String>,
    /// Request persistent memory on every run (default: `false`).
    pub persistent_memory: Option<bool>,
}
impl RunAgentClientConfig {
    /// Creates a configuration for `agent_id` / `entrypoint_tag` with every
    /// optional setting left unset.
    #[must_use]
    pub fn new(agent_id: impl Into<String>, entrypoint_tag: impl Into<String>) -> Self {
        // Reuse `Default` so the list of unset optional fields lives in one place.
        Self {
            agent_id: agent_id.into(),
            entrypoint_tag: entrypoint_tag.into(),
            ..Self::default()
        }
    }

    /// Selects a local agent server (`true`) or the remote API (`false`).
    #[must_use]
    pub fn with_local(mut self, local: bool) -> Self {
        self.local = Some(local);
        self
    }

    /// Sets the host and port of a local agent.
    #[must_use]
    pub fn with_address(mut self, host: impl Into<String>, port: u16) -> Self {
        self.host = Some(host.into());
        self.port = Some(port);
        self
    }

    /// Sets the API key used against the remote API.
    #[must_use]
    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// Sets the base URL of the remote API.
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    /// Attaches extra parameters to the configuration.
    #[must_use]
    pub fn with_extra_params(mut self, extra_params: HashMap<String, Value>) -> Self {
        self.extra_params = Some(extra_params);
        self
    }

    /// Enables or disables the registry-database address lookup for local agents.
    #[must_use]
    pub fn with_enable_registry(mut self, enable: bool) -> Self {
        self.enable_registry = Some(enable);
        self
    }

    /// Sets the user identifier forwarded with run requests.
    #[must_use]
    pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
        self.user_id = Some(user_id.into());
        self
    }

    /// Enables or disables persistent memory for run requests.
    #[must_use]
    pub fn with_persistent_memory(mut self, persistent: bool) -> Self {
        self.persistent_memory = Some(persistent);
        self
    }
}
impl RunAgentClient {
    /// Builds a client from `config`, resolves the agent address, fetches the
    /// agent architecture and validates the configured entrypoint against it.
    ///
    /// Address resolution:
    /// - Local clients connect to `http://host:port` / `ws://host:port`. The
    ///   address comes from `config.host`/`config.port` when both are set.
    ///   Otherwise, when the registry is enabled (the default for local
    ///   clients), it is looked up in the agent database. The lookup only
    ///   happens with the `db` feature.
    /// - Remote clients use `config.base_url`, then the `ENV_RUNAGENT_BASE_URL`
    ///   environment variable, then `DEFAULT_BASE_URL`. The API key falls back
    ///   to `ENV_RUNAGENT_API_KEY`.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - a local client ends up without a host or port;
    /// - the registry reports a port outside the `u16` range;
    /// - constructing the serializer, database service or transport clients fails;
    /// - fetching the architecture fails;
    /// - the entrypoint tag is not listed in the architecture.
    pub async fn new(config: RunAgentClientConfig) -> RunAgentResult<Self> {
        use crate::constants::{DEFAULT_BASE_URL, ENV_RUNAGENT_API_KEY, ENV_RUNAGENT_BASE_URL};

        let local = config.local.unwrap_or(false);
        let enable_registry = config.enable_registry.unwrap_or(local);

        let (host, port) = if !local {
            (None, None)
        } else if let (Some(h), Some(p)) = (&config.host, config.port) {
            (Some(h.clone()), Some(p))
        } else if enable_registry {
            Self::lookup_registered_address(&config).await?
        } else {
            (config.host.clone(), config.port)
        };

        let api_key = config
            .api_key
            .or_else(|| std::env::var(ENV_RUNAGENT_API_KEY).ok());
        let base_url = config
            .base_url
            .or_else(|| std::env::var(ENV_RUNAGENT_BASE_URL).ok())
            .unwrap_or_else(|| DEFAULT_BASE_URL.to_string());

        if !local {
            tracing::info!("🌐 Connecting to remote agent at {}", base_url);
            if api_key.is_some() {
                tracing::debug!("🔑 API key provided");
            } else {
                tracing::warn!("⚠️ No API key provided - using default limits");
            }
        }

        let serializer = CoreSerializer::new(10.0)?;

        let (rest_client, socket_client) = if local {
            let host = host.ok_or_else(|| {
                RunAgentError::validation(
                    "Host is required for local clients. Provide host/port in config or enable registry for database lookup.",
                )
            })?;
            let port = port.ok_or_else(|| {
                RunAgentError::validation(
                    "Port is required for local clients. Provide host/port in config or enable registry for database lookup.",
                )
            })?;
            tracing::info!("🔌 Using address: {}:{}", host, port);
            let agent_base_url = format!("http://{}:{}", host, port);
            let agent_socket_url = format!("ws://{}:{}", host, port);
            let rest_client = RestClient::new(&agent_base_url, None, Some("/api/v1"))?;
            let socket_client = SocketClient::new(&agent_socket_url, None, Some("/api/v1"))?;
            (rest_client, socket_client)
        } else {
            Self::create_remote_clients(Some(&base_url), api_key)?
        };

        let mut client = Self {
            agent_id: config.agent_id,
            entrypoint_tag: config.entrypoint_tag,
            local,
            rest_client,
            socket_client,
            serializer,
            agent_architecture: None,
            extra_params: config.extra_params,
            user_id: config.user_id,
            persistent_memory: config.persistent_memory.unwrap_or(false),
            // Only declared under the feature: `DatabaseService` is not in
            // scope otherwise.
            #[cfg(feature = "db")]
            db_service: None,
        };
        client.initialize_architecture().await?;
        Ok(client)
    }

    /// Resolves a local agent's address from the registry database.
    ///
    /// Falls back to `config.host` / `config.port` when the agent is not
    /// registered.
    #[cfg(feature = "db")]
    async fn lookup_registered_address(
        config: &RunAgentClientConfig,
    ) -> RunAgentResult<(Option<String>, Option<u16>)> {
        let db_service = DatabaseService::new(None).await?;
        match db_service.get_agent(&config.agent_id).await? {
            Some(agent_info) => {
                tracing::info!(
                    "🔍 Found agent in database: {}:{}",
                    agent_info.host,
                    agent_info.port
                );
                // Checked conversion: an `as` cast would silently wrap an
                // out-of-range port into a different, wrong port.
                let port: u16 = std::convert::TryFrom::try_from(agent_info.port).map_err(|_| {
                    RunAgentError::validation(format!(
                        "Agent {} is registered with an invalid port: {}",
                        config.agent_id, agent_info.port
                    ))
                })?;
                Ok((Some(agent_info.host), Some(port)))
            }
            None => Ok((config.host.clone(), config.port)),
        }
    }

    /// Without the `db` feature there is no registry, so this only returns
    /// whatever address the configuration holds.
    #[cfg(not(feature = "db"))]
    async fn lookup_registered_address(
        config: &RunAgentClientConfig,
    ) -> RunAgentResult<(Option<String>, Option<u16>)> {
        Ok((config.host.clone(), config.port))
    }

    /// Fetches and caches the agent architecture, then validates the entrypoint.
    async fn initialize_architecture(&mut self) -> RunAgentResult<()> {
        let architecture = self.get_agent_architecture_internal().await?;
        self.agent_architecture = Some(architecture);
        self.validate_entrypoint()?;
        Ok(())
    }

    /// Fetches the agent architecture JSON over REST.
    async fn get_agent_architecture_internal(&self) -> RunAgentResult<Value> {
        self.get_agent_architecture().await
    }

    /// Checks that `entrypoint_tag` appears under `entrypoints[*].tag` in the
    /// cached architecture.
    ///
    /// Validation is skipped (returns `Ok`) when no architecture is cached or
    /// it has no `entrypoints` array.
    fn validate_entrypoint(&self) -> RunAgentResult<()> {
        let entrypoints = match self
            .agent_architecture
            .as_ref()
            .and_then(|architecture| architecture.get("entrypoints"))
            .and_then(Value::as_array)
        {
            Some(entrypoints) => entrypoints,
            None => return Ok(()),
        };

        let wanted = Some(self.entrypoint_tag.as_str());
        if entrypoints
            .iter()
            .any(|ep| ep.get("tag").and_then(Value::as_str) == wanted)
        {
            return Ok(());
        }

        let available: Vec<&str> = entrypoints
            .iter()
            .filter_map(|ep| ep.get("tag").and_then(Value::as_str))
            .collect();
        tracing::error!(
            "Entrypoint `{}` not found for agent {}. Available: {:?}",
            self.entrypoint_tag,
            self.agent_id,
            available
        );
        Err(RunAgentError::validation(format!(
            "Entrypoint `{}` not found in agent {}",
            self.entrypoint_tag, self.agent_id
        )))
    }

    /// Runs a non-streaming entrypoint with keyword arguments only.
    ///
    /// # Errors
    ///
    /// See [`RunAgentClient::run_with_args`].
    pub async fn run(&self, input_kwargs: &[(&str, Value)]) -> RunAgentResult<Value> {
        self.run_with_args(&[], input_kwargs).await
    }

    /// Runs a non-streaming entrypoint over REST and returns its deserialized
    /// result.
    ///
    /// Returns `Value::Null` when the response carries no recognised payload.
    ///
    /// # Errors
    ///
    /// Returns a validation error in two cases:
    /// - the entrypoint tag ends in `_stream`;
    /// - the agent answered with a generator object, i.e. a streaming function.
    ///
    /// Returns a server error when the response is not marked `success`.
    /// Transport and deserialization errors are propagated.
    pub async fn run_with_args(
        &self,
        input_args: &[Value],
        input_kwargs: &[(&str, Value)],
    ) -> RunAgentResult<Value> {
        if self.entrypoint_tag.ends_with("_stream") {
            return Err(RunAgentError::validation(
                "Use run_stream for streaming entrypoints".to_string(),
            ));
        }
        let input_kwargs_map = Self::kwargs_to_map(input_kwargs);
        let response = self
            .rest_client
            .run_agent(
                &self.agent_id,
                &self.entrypoint_tag,
                input_args,
                &input_kwargs_map,
                self.user_id.as_deref(),
                self.persistent_memory,
            )
            .await?;

        let succeeded = response
            .get("success")
            .and_then(Value::as_bool)
            .unwrap_or(false);
        if !succeeded {
            return Self::failure_from(response.get("error"));
        }

        match self.extract_payload(response.get("data"), response.get("output_data"))? {
            Some(payload) => {
                self.ensure_not_generator(&payload)?;
                let deserialized = self.serializer.deserialize_object(payload)?;
                Ok(deserialized)
            }
            None => Ok(Value::Null),
        }
    }

    /// Picks the result payload out of a successful response.
    ///
    /// Recognised shapes, in order of precedence:
    /// - `data` is a string: passed through `prepare_for_deserialization`;
    /// - `data.result_data.data`: returned unchanged;
    /// - `data` is any other object: returned unchanged;
    /// - `data` is absent: `output_data` is returned unchanged.
    ///
    /// Yields `Ok(None)` when none of these produce a value.
    fn extract_payload(
        &self,
        data: Option<&Value>,
        output_data: Option<&Value>,
    ) -> RunAgentResult<Option<Value>> {
        let data = match data {
            Some(data) => data,
            None => return Ok(output_data.cloned()),
        };
        if data.is_string() {
            // Inspect the raw string before the serializer rewrites it.
            self.ensure_not_generator(data)?;
            return Ok(Some(
                self.serializer.prepare_for_deserialization(data.clone()),
            ));
        }
        if let Some(result_data) = data.get("result_data") {
            return Ok(result_data.get("data").cloned());
        }
        if data.is_object() {
            Ok(Some(data.clone()))
        } else {
            Ok(None)
        }
    }

    /// Rejects string values that look like a generator object (they contain
    /// `generator object` or `<generator`, case-insensitively). The error points
    /// the caller at the `_stream` variant of the entrypoint.
    fn ensure_not_generator(&self, value: &Value) -> RunAgentResult<()> {
        let lowered = match value.as_str() {
            Some(text) => text.to_lowercase(),
            None => return Ok(()),
        };
        if lowered.contains("generator object") || lowered.contains("<generator") {
            let streaming_tag = format!("{}_stream", self.entrypoint_tag);
            return Err(RunAgentError::validation(format!(
                "Agent returned a generator object instead of content. This entrypoint appears to be a streaming function.\n\
                 Try using the streaming endpoint: `{}`\n\
                 Or use `run_stream()` method instead of `run()`.",
                streaming_tag
            )));
        }
        Ok(())
    }

    /// Builds the `Err` returned for an unsuccessful response from its `error`
    /// field. This function always returns `Err`.
    ///
    /// An object error is rendered as `[code] message`. A message without a
    /// code is kept as-is, and a code without a message still surfaces the code.
    fn failure_from(error_info: Option<&Value>) -> RunAgentResult<Value> {
        if let Some(error_info) = error_info {
            if let Some(error_obj) = error_info.as_object() {
                let message = error_obj.get("message").and_then(Value::as_str);
                let code = error_obj.get("code").and_then(Value::as_str);
                match (code, message) {
                    (Some(code), Some(message)) => {
                        return Err(RunAgentError::server(format!("[{}] {}", code, message)));
                    }
                    (None, Some(message)) => return Err(RunAgentError::server(message)),
                    (Some(code), None) => {
                        return Err(RunAgentError::server(format!("[{}] Unknown error", code)));
                    }
                    (None, None) => {}
                }
            }
            if let Some(error_msg) = error_info.as_str() {
                return Err(RunAgentError::server(error_msg));
            }
        }
        Err(RunAgentError::server("Unknown error"))
    }

    /// Converts borrowed keyword-argument pairs into the owned map the
    /// transport clients expect.
    fn kwargs_to_map(input_kwargs: &[(&str, Value)]) -> HashMap<String, Value> {
        input_kwargs
            .iter()
            .map(|(key, value)| ((*key).to_string(), value.clone()))
            .collect()
    }

    /// Runs a streaming entrypoint with keyword arguments only.
    ///
    /// # Errors
    ///
    /// See [`RunAgentClient::run_stream_with_args`].
    pub async fn run_stream(
        &self,
        input_kwargs: &[(&str, Value)],
    ) -> RunAgentResult<Pin<Box<dyn Stream<Item = RunAgentResult<Value>> + Send>>> {
        self.run_stream_with_args(&[], input_kwargs).await
    }

    /// Runs a streaming entrypoint over the WebSocket client and returns the
    /// stream of results.
    ///
    /// # Errors
    ///
    /// Returns a validation error if the entrypoint tag does not end in
    /// `_stream`. Errors from opening the stream are propagated.
    pub async fn run_stream_with_args(
        &self,
        input_args: &[Value],
        input_kwargs: &[(&str, Value)],
    ) -> RunAgentResult<Pin<Box<dyn Stream<Item = RunAgentResult<Value>> + Send>>> {
        if !self.entrypoint_tag.ends_with("_stream") {
            return Err(RunAgentError::validation(
                "Use run() for non-stream entrypoints".to_string(),
            ));
        }
        let input_kwargs_map = Self::kwargs_to_map(input_kwargs);
        self.socket_client
            .run_stream(
                &self.agent_id,
                &self.entrypoint_tag,
                input_args,
                &input_kwargs_map,
                self.user_id.as_deref(),
                self.persistent_memory,
            )
            .await
    }

    /// Fetches the agent architecture JSON over REST; nothing is cached.
    pub async fn get_agent_architecture(&self) -> RunAgentResult<Value> {
        self.rest_client
            .get_agent_architecture(&self.agent_id)
            .await
    }

    /// Reports whether the REST health check succeeded. Failures map to
    /// `Ok(false)`; this never returns `Err`.
    pub async fn health_check(&self) -> RunAgentResult<bool> {
        Ok(self.rest_client.health_check().await.is_ok())
    }

    /// Identifier of the target agent.
    pub fn agent_id(&self) -> &str {
        &self.agent_id
    }

    /// Entrypoint this client invokes.
    pub fn entrypoint_tag(&self) -> &str {
        &self.entrypoint_tag
    }

    /// Extra parameters supplied in the configuration, if any.
    pub fn extra_params(&self) -> Option<&HashMap<String, Value>> {
        self.extra_params.as_ref()
    }

    /// User identifier forwarded with run requests, if any.
    pub fn user_id(&self) -> Option<&str> {
        self.user_id.as_deref()
    }

    /// Whether run requests ask for persistent memory.
    pub fn persistent_memory(&self) -> bool {
        self.persistent_memory
    }

    /// Whether this client targets a local agent server.
    pub fn is_local(&self) -> bool {
        self.local
    }
}
impl RunAgentClient {
    /// Builds the REST and WebSocket clients for a remote agent.
    ///
    /// With a base URL, the REST client uses it as-is. The WebSocket URL is
    /// derived by swapping the leading scheme: `https://` → `wss://` and
    /// `http://` → `ws://`; a scheme-less URL gets `wss://`. Without a base
    /// URL, both clients use their `default()` constructors.
    ///
    /// # Errors
    ///
    /// Propagates construction errors from either client; the REST client is
    /// built first.
    fn create_remote_clients(
        base_url_override: Option<&str>,
        api_key_override: Option<String>,
    ) -> RunAgentResult<(RestClient, SocketClient)> {
        let base_url = match base_url_override {
            Some(base_url) => base_url,
            None => return Ok((RestClient::default()?, SocketClient::default()?)),
        };
        let rest_client = RestClient::new(base_url, api_key_override.clone(), Some("/api/v1"))?;
        let socket_base = Self::socket_base_for(base_url);
        let socket_client = SocketClient::new(&socket_base, api_key_override, Some("/api/v1"))?;
        Ok((rest_client, socket_client))
    }

    /// Maps an HTTP(S) base URL to its WebSocket equivalent.
    ///
    /// Only the leading scheme is rewritten. `str::replace` would also rewrite
    /// any `http(s)://` that appears later in the URL, e.g. in a query string.
    fn socket_base_for(base_url: &str) -> String {
        if let Some(rest) = base_url.strip_prefix("https://") {
            format!("wss://{}", rest)
        } else if let Some(rest) = base_url.strip_prefix("http://") {
            format!("ws://{}", rest)
        } else {
            format!("wss://{}", base_url)
        }
    }
}