use crate::client::RunAgentClient as AsyncRunAgentClient;
use crate::types::{RunAgentError, RunAgentResult};
use futures::Stream;
use serde_json::Value;
use std::collections::HashMap;
use std::pin::Pin;
use tokio::runtime::Runtime;
pub use crate::client::RunAgentClientConfig;
/// Synchronous (blocking) facade over the asynchronous [`AsyncRunAgentClient`].
///
/// Every blocking method drives the corresponding async call to completion on a Tokio
/// runtime that this struct owns.
pub struct RunAgentClient {
// The wrapped async client. It is declared before `runtime` on purpose: Rust drops struct
// fields in declaration order, so the client is torn down while the runtime it was
// created on (see `new`) is still alive. Do not reorder these fields.
inner: AsyncRunAgentClient,
// Runtime created in `new`; used to `block_on` each async call of `inner`.
runtime: Runtime,
}
impl RunAgentClient {
pub fn new(config: RunAgentClientConfig) -> RunAgentResult<Self> {
let runtime = Runtime::new()
.map_err(|e| RunAgentError::connection(format!("Failed to create runtime: {}", e)))?;
let inner = runtime.block_on(AsyncRunAgentClient::new(config))?;
Ok(Self { inner, runtime })
}
pub fn run(&self, input_kwargs: &[(&str, Value)]) -> RunAgentResult<Value> {
self.runtime.block_on(self.inner.run(input_kwargs))
}
pub fn run_with_args(
&self,
input_args: &[Value],
input_kwargs: &[(&str, Value)],
) -> RunAgentResult<Value> {
self.runtime
.block_on(self.inner.run_with_args(input_args, input_kwargs))
}
pub fn run_stream(&self, input_kwargs: &[(&str, Value)]) -> RunAgentResult<BlockingStream> {
let stream = self.runtime.block_on(self.inner.run_stream(input_kwargs))?;
Ok(BlockingStream::new(stream))
}
pub fn run_stream_with_args(
&self,
input_args: &[Value],
input_kwargs: &[(&str, Value)],
) -> RunAgentResult<BlockingStream> {
let stream = self
.runtime
.block_on(self.inner.run_stream_with_args(input_args, input_kwargs))?;
Ok(BlockingStream::new(stream))
}
pub fn get_agent_architecture(&self) -> RunAgentResult<Value> {
self.runtime.block_on(self.inner.get_agent_architecture())
}
pub fn health_check(&self) -> RunAgentResult<bool> {
self.runtime.block_on(self.inner.health_check())
}
pub fn agent_id(&self) -> &str {
self.inner.agent_id()
}
pub fn entrypoint_tag(&self) -> &str {
self.inner.entrypoint_tag()
}
pub fn extra_params(&self) -> Option<&HashMap<String, Value>> {
self.inner.extra_params()
}
pub fn user_id(&self) -> Option<&str> {
self.inner.user_id()
}
pub fn persistent_memory(&self) -> bool {
self.inner.persistent_memory()
}
pub fn is_local(&self) -> bool {
self.inner.is_local()
}
}
/// Blocking iterator over the items of an asynchronous result stream.
///
/// The async stream is driven on a background thread spawned by the constructor, and
/// each item is forwarded through a channel that the `Iterator` implementation reads from.
pub struct BlockingStream {
// Receiving end of the channel fed by the background thread. It becomes disconnected
// once that thread exits and drops its sender.
receiver: std::sync::mpsc::Receiver<RunAgentResult<Value>>,
// Handle to the background thread. It is never joined: dropping the stream detaches the
// thread, which keeps running until the async stream ends or its next send fails.
_handle: std::thread::JoinHandle<()>, }
impl BlockingStream {
pub(crate) fn new(
mut stream: Pin<Box<dyn Stream<Item = RunAgentResult<Value>> + Send>>,
) -> Self {
use futures::StreamExt;
use std::sync::mpsc;
use std::thread;
let (tx, rx) = mpsc::channel();
let handle = thread::spawn(move || {
let rt = Runtime::new().expect("Failed to create runtime");
rt.block_on(async move {
while let Some(item) = stream.next().await {
if tx.send(item).is_err() {
break;
}
}
});
});
Self {
receiver: rx,
_handle: handle,
}
}
}
impl Iterator for BlockingStream {
    type Item = RunAgentResult<Value>;

    /// Blocks until the background thread forwards the next item.
    ///
    /// Returns `None` once the producing thread has finished (dropping its sender) and
    /// every item it already sent has been consumed.
    fn next(&mut self) -> Option<Self::Item> {
        match self.receiver.recv() {
            Ok(item) => Some(item),
            // `recv` only fails after the sender is gone and the buffer is drained.
            Err(_) => None,
        }
    }
}