use std::time::Duration;
use reqwest::Client;
use serde::de::DeserializeOwned;
use serde_json::json;
use tracing::debug;
use uuid::Uuid;
use crate::error::Error;
use crate::models::*;
/// Base URL used by `Starflask::new` when the caller passes `None` for `base_url`.
const DEFAULT_BASE_URL: &str = "https://starflask.com/api";
/// Settings for waiting on a session to reach a terminal state.
#[derive(Debug, Clone)]
pub struct PollConfig {
    /// Overall deadline, measured from when waiting begins.
    pub timeout: Duration,
    /// Pause between consecutive status checks.
    pub interval: Duration,
}

impl Default for PollConfig {
    /// A two-minute deadline with a status check every two seconds.
    fn default() -> Self {
        let timeout = Duration::from_secs(2 * 60);
        let interval = Duration::from_millis(2_000);
        PollConfig { timeout, interval }
    }
}
/// Async client for the Starflask HTTP API.
///
/// Every request is sent to `base_url` followed by the endpoint path and
/// carries `api_key` as a bearer token.
#[derive(Clone)]
pub struct Starflask {
    api_key: String,
    base_url: String,
    client: Client,
    /// Deadline and interval used by the methods that wait for a session to
    /// finish (`query`, `query_with_persona`, `fire_hook_and_wait`).
    pub poll_config: PollConfig,
}

// `Debug` is written by hand instead of derived so that the API key never
// shows up in logs, `{:?}` output, or panic messages.
impl std::fmt::Debug for Starflask {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Starflask")
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("client", &self.client)
            .field("poll_config", &self.poll_config)
            .finish()
    }
}
impl Starflask {
pub fn new(api_key: &str, base_url: Option<&str>) -> Result<Self, Error> {
let client = Client::builder()
.timeout(Duration::from_secs(90))
.build()?;
Ok(Self {
api_key: api_key.to_string(),
base_url: base_url
.unwrap_or(DEFAULT_BASE_URL)
.trim_end_matches('/')
.to_string(),
client,
poll_config: PollConfig::default(),
})
}
async fn request<T: DeserializeOwned>(
&self,
method: reqwest::Method,
path: &str,
body: Option<serde_json::Value>,
) -> Result<T, Error> {
let url = format!("{}{}", self.base_url, path);
let mut req = self
.client
.request(method, &url)
.bearer_auth(&self.api_key);
if let Some(body) = body {
req = req.json(&body);
}
let resp = req.send().await?;
if !resp.status().is_success() {
let status = resp.status().as_u16();
let body = resp.text().await.unwrap_or_default();
return Err(Error::Api { status, body });
}
resp.json().await.map_err(Error::Request)
}
async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T, Error> {
self.request(reqwest::Method::GET, path, None).await
}
async fn post<T: DeserializeOwned>(
&self,
path: &str,
body: serde_json::Value,
) -> Result<T, Error> {
self.request(reqwest::Method::POST, path, Some(body)).await
}
async fn put<T: DeserializeOwned>(
&self,
path: &str,
body: serde_json::Value,
) -> Result<T, Error> {
self.request(reqwest::Method::PUT, path, Some(body)).await
}
async fn delete_req(&self, path: &str) -> Result<serde_json::Value, Error> {
self.request(reqwest::Method::DELETE, path, None).await
}
async fn poll_session(&self, agent_id: &Uuid, session_id: &Uuid) -> Result<Session, Error> {
let path = format!("/agents/{}/sessions/{}", agent_id, session_id);
let deadline = tokio::time::Instant::now() + self.poll_config.timeout;
loop {
tokio::time::sleep(self.poll_config.interval).await;
if tokio::time::Instant::now() > deadline {
return Err(Error::Timeout(self.poll_config.timeout));
}
let session: Session = self.get(&path).await?;
debug!(
session_id = %session_id,
status = %session.status,
"polling session"
);
match session.status.as_str() {
"completed" => return Ok(session),
"failed" => {
let msg = session.error.unwrap_or_else(|| "unknown error".into());
return Err(Error::SessionFailed(msg));
}
_ => continue,
}
}
}
pub async fn list_agents(&self) -> Result<Vec<Agent>, Error> {
self.get("/agents").await
}
pub async fn create_agent(&self, name: &str) -> Result<Agent, Error> {
self.post("/agents", json!({ "name": name })).await
}
pub async fn get_agent(&self, agent_id: &Uuid) -> Result<Agent, Error> {
self.get(&format!("/agents/{}", agent_id)).await
}
pub async fn update_agent(
&self,
agent_id: &Uuid,
name: Option<&str>,
description: Option<&str>,
) -> Result<Agent, Error> {
let mut body = json!({});
if let Some(n) = name {
body["name"] = json!(n);
}
if let Some(d) = description {
body["description"] = json!(d);
}
self.put(&format!("/agents/{}", agent_id), body).await
}
pub async fn delete_agent(&self, agent_id: &Uuid) -> Result<serde_json::Value, Error> {
self.delete_req(&format!("/agents/{}", agent_id)).await
}
pub async fn set_agent_active(
&self,
agent_id: &Uuid,
active: bool,
) -> Result<serde_json::Value, Error> {
self.put(
&format!("/agents/{}/active", agent_id),
json!({ "active": active }),
)
.await
}
pub async fn query(
&self,
agent_id: &Uuid,
message: &str,
) -> Result<Session, Error> {
self.query_with_persona(agent_id, message, None).await
}
pub async fn query_with_persona(
&self,
agent_id: &Uuid,
message: &str,
persona: Option<&str>,
) -> Result<Session, Error> {
let mut body = json!({ "message": message });
if let Some(p) = persona {
body["persona"] = json!(p);
}
let session: Session = self
.post(&format!("/agents/{}/query", agent_id), body)
.await?;
self.poll_session(agent_id, &session.id).await
}
pub async fn get_hooks(&self, agent_id: &Uuid) -> Result<HooksResponse, Error> {
self.get(&format!("/agents/{}/hooks", agent_id)).await
}
pub async fn fire_hook(
&self,
agent_id: &Uuid,
event: &str,
payload: serde_json::Value,
) -> Result<Session, Error> {
self.post(
&format!("/agents/{}/fire_hook", agent_id),
json!({ "event": event, "payload": payload }),
)
.await
}
pub async fn fire_hook_and_wait(
&self,
agent_id: &Uuid,
event: &str,
payload: serde_json::Value,
) -> Result<Session, Error> {
let session = self.fire_hook(agent_id, event, payload).await?;
self.poll_session(agent_id, &session.id).await
}
pub async fn list_sessions(
&self,
agent_id: &Uuid,
limit: Option<u32>,
) -> Result<Vec<Session>, Error> {
let path = match limit {
Some(l) => format!("/agents/{}/sessions?limit={}", agent_id, l),
None => format!("/agents/{}/sessions", agent_id),
};
self.get(&path).await
}
pub async fn get_session(
&self,
agent_id: &Uuid,
session_id: &Uuid,
) -> Result<Session, Error> {
self.get(&format!("/agents/{}/sessions/{}", agent_id, session_id))
.await
}
pub async fn install_agent_pack(
&self,
agent_id: &Uuid,
content_hash: &str,
) -> Result<serde_json::Value, Error> {
self.put(
&format!("/agents/{}/agent-pack", agent_id),
json!({ "content_hash": content_hash }),
)
.await
}
pub async fn provision_pack(
&self,
agent_id: &Uuid,
pack_definition: serde_json::Value,
) -> Result<ProvisionResult, Error> {
self.post(
&format!("/agents/{}/provision-pack", agent_id),
pack_definition,
)
.await
}
pub async fn list_integrations(
&self,
agent_id: &Uuid,
) -> Result<Vec<Integration>, Error> {
self.get(&format!("/agents/{}/integrations", agent_id))
.await
}
pub async fn create_integration(
&self,
agent_id: &Uuid,
platform: &str,
) -> Result<Integration, Error> {
self.post(
&format!("/agents/{}/integrations", agent_id),
json!({ "platform": platform }),
)
.await
}
pub async fn delete_integration(
&self,
agent_id: &Uuid,
integration_id: &Uuid,
) -> Result<serde_json::Value, Error> {
self.delete_req(&format!(
"/agents/{}/integrations/{}",
agent_id, integration_id
))
.await
}
pub async fn list_tasks(&self, agent_id: &Uuid) -> Result<Vec<AgentTask>, Error> {
self.get(&format!("/agents/{}/tasks", agent_id)).await
}
pub async fn create_task(
&self,
agent_id: &Uuid,
name: &str,
hook_event: Option<&str>,
schedule: Option<&str>,
) -> Result<AgentTask, Error> {
let mut body = json!({ "name": name });
if let Some(e) = hook_event {
body["hook_event"] = json!(e);
}
if let Some(s) = schedule {
body["schedule"] = json!(s);
}
self.post(&format!("/agents/{}/tasks", agent_id), body)
.await
}
pub async fn list_memories(
&self,
agent_id: &Uuid,
limit: Option<u32>,
offset: Option<u32>,
) -> Result<serde_json::Value, Error> {
let mut params = vec![];
if let Some(l) = limit {
params.push(format!("limit={}", l));
}
if let Some(o) = offset {
params.push(format!("offset={}", o));
}
let qs = if params.is_empty() {
String::new()
} else {
format!("?{}", params.join("&"))
};
self.get(&format!("/agents/{}/memories{}", agent_id, qs))
.await
}
pub async fn get_subscription_status(&self) -> Result<serde_json::Value, Error> {
self.get("/subscriptions/status").await
}
}