#![allow(dead_code)]
use anyhow::{Context, Result};
use clawgarden_proto::MessagePayload;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tokio::time::timeout;
/// Per-transport deadline, in milliseconds, for a single pi RPC attempt.
const RPC_TIMEOUT_MS: u64 = 1_600;
/// Request payload for one pi engine RPC.
///
/// Serialized as JSON (field names as declared) by both the UDS transport
/// (`serde_json::to_vec`) and the HTTP transport (`RequestBuilder::json`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PiRpcRequest {
    pub agent_name: String,
    pub persona: String,
    pub memory: String,
    pub conversation_id: String,
    pub correlation_id: String,
    pub content: String,
    pub recent_messages: Vec<String>,
}
/// Reply decoded from the pi engine over either transport.
///
/// The transport functions in this module forward only `response` into the
/// resulting `MessagePayload`; `confidence` is deserialized but not read there.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PiRpcResponse {
    pub response: String,
    pub confidence: f32,
}
/// Sends `request` to the pi engine, trying the Unix domain socket transport
/// first and falling back to HTTP if the UDS attempt fails for any reason.
///
/// # Errors
///
/// Returns an error only when both transports fail. The returned error wraps
/// the HTTP failure (still reachable through the error chain) with a context
/// message that includes the UDS failure, so neither cause is lost.
pub async fn call_pi_rpc(request: PiRpcRequest) -> Result<MessagePayload> {
    let uds_err = match call_pi_rpc_uds(&request).await {
        Ok(payload) => return Ok(payload),
        Err(e) => e,
    };
    // `{:#}` renders the full anyhow context chain (e.g. the underlying
    // io::Error), not only the outermost message.
    log::debug!("UDS call failed, trying HTTP: {:#}", uds_err);

    call_pi_rpc_http(&request).await.map_err(|http_err| {
        http_err.context(format!(
            "pi RPC failed over both UDS and HTTP (UDS error: {:#})",
            uds_err
        ))
    })
}
async fn call_pi_rpc_uds(request: &PiRpcRequest) -> Result<MessagePayload> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::UnixStream;
let sock_path = "/tmp/pi-engine.sock";
let mut stream = UnixStream::connect(sock_path)
.await
.context("Failed to connect to pi-engine UDS socket")?;
let data = serde_json::to_vec(request)?;
let len = data.len() as u32;
let mut buf = Vec::new();
buf.extend_from_slice(&len.to_be_bytes());
buf.extend_from_slice(&data);
stream.write_all(&buf).await?;
stream.flush().await?;
let mut len_buf = [0u8; 4];
stream.read_exact(&mut len_buf).await?;
let len = u32::from_be_bytes(len_buf) as usize;
let mut response_buf = vec![0u8; len];
stream.read_exact(&mut response_buf).await?;
let response: PiRpcResponse = serde_json::from_slice(&response_buf)?;
Ok(MessagePayload {
content: response.response,
context: vec![],
})
}
async fn call_pi_rpc_http(request: &PiRpcRequest) -> Result<MessagePayload> {
let client = reqwest::Client::new();
let url = "http://localhost:3001/rpc";
let response = timeout(
Duration::from_millis(RPC_TIMEOUT_MS),
client.post(url).json(request).send(),
)
.await
.context("pi RPC HTTP call timed out")?
.context("pi RPC HTTP call failed")?;
let rpc_response: PiRpcResponse = response.json().await?;
Ok(MessagePayload {
content: rpc_response.response,
context: vec![],
})
}
pub async fn call_pi_rpc_safe(request: PiRpcRequest) -> Result<MessagePayload> {
let result = timeout(
Duration::from_millis(RPC_TIMEOUT_MS + 200), call_pi_rpc(request),
)
.await;
match result {
Ok(Ok(payload)) => Ok(payload),
Ok(Err(e)) => {
log::error!("pi RPC call failed: {}", e);
Err(e)
}
Err(_) => {
log::error!("pi RPC call timed out after {}ms", RPC_TIMEOUT_MS);
Err(anyhow::anyhow!("pi RPC timeout"))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every request field must survive a JSON round trip unchanged.
    #[test]
    fn test_rpc_request_serialization() {
        let request = PiRpcRequest {
            agent_name: "alex".to_string(),
            persona: "Role: pm".to_string(),
            memory: String::new(),
            conversation_id: "conv_1".to_string(),
            correlation_id: "req_1".to_string(),
            content: "Hello".to_string(),
            recent_messages: vec!["earlier".to_string()],
        };
        let json = serde_json::to_string(&request).unwrap();
        let restored: PiRpcRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.agent_name, "alex");
        assert_eq!(restored.persona, "Role: pm");
        assert_eq!(restored.memory, "");
        assert_eq!(restored.conversation_id, "conv_1");
        assert_eq!(restored.correlation_id, "req_1");
        assert_eq!(restored.content, "Hello");
        assert_eq!(restored.recent_messages, vec!["earlier".to_string()]);
    }

    /// The wire shape the transports decode replies from.
    #[test]
    fn test_rpc_response_deserialization() {
        let json = r#"{"response":"Hi there","confidence":0.75}"#;
        let response: PiRpcResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.response, "Hi there");
        assert_eq!(response.confidence, 0.75);
    }
}