use crate::Session;
use serde::{Deserialize, Serialize};
/// Author of a single message in a conversation history.
///
/// Serde uses its default representation (the variant name, e.g. `"User"`);
/// the lowercase name sent to provider APIs comes from [`MessageRole::as_str`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MessageRole {
    /// System prompt / instructions.
    System,
    /// A turn written by the user.
    User,
    /// A reply produced by the model.
    Assistant,
}
impl MessageRole {
pub fn as_str(&self) -> &'static str {
match self {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
}
}
}
impl std::fmt::Display for MessageRole {
    /// Writes the same lowercase name as [`MessageRole::as_str`]
    /// (width/fill flags of the formatter are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// One entry of a conversation history: who wrote it and what it says.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Author of the message.
    pub role: MessageRole,
    /// Text body of the message.
    pub content: String,
}
impl ChatMessage {
pub fn new(role: MessageRole, content: impl Into<String>) -> Self {
Self {
role,
content: content.into(),
}
}
}
/// Backend an `AgentClientSession` talks to; it selects the endpoints, the
/// request body shape and how responses are parsed.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum LlmProvider {
    /// OpenAI-compatible API (`/v1/chat/completions`, `/v1/models`).
    OpenAI,
    /// Anthropic Messages API (`/v1/messages`); `list_models` returns an
    /// empty list for it.
    Anthropic,
    /// Ollama server (`/api/chat`, `/api/tags`).
    Ollama,
    /// OpenCode server, driven through its session API (`/session`,
    /// `/session/{id}/message`, `/config/providers`).
    OpenCode,
}
/// Item delivered on the channel returned by `AgentClientSession::chat_stream`.
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// A piece of reply text, in arrival order.
    Content(String),
    /// End of the reply; sent after the last `Content` piece.
    Done,
}
/// A model advertised by a provider's model-listing endpoint.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Identifier accepted by `set_model`; for OpenCode it has the form
    /// `provider:model`.
    pub id: String,
    /// Human-readable name (falls back to the id when none is reported).
    pub name: String,
    /// Provider that serves the model (e.g. `"openai"`, `"ollama"`, or the
    /// OpenCode provider id).
    pub provider_id: String,
}
/// A stateful chat client bound to one LLM provider.
///
/// Keeps the local conversation history, the selected model and, for
/// OpenCode, the server-side session identifiers.
pub struct AgentClientSession {
    /// Backend this client talks to.
    provider: LlmProvider,
    /// Server root; endpoint paths are appended as `"{base_url}/..."`, so it
    /// presumably should not end with a slash.
    base_url: String,
    /// Optional credential, sent as a bearer token and/or `x-api-key`
    /// depending on the provider.
    api_key: Option<String>,
    /// Selected model id; `None` until `set_model` succeeds.
    model: Option<String>,
    /// HTTP session used for every request.
    session: Session,
    /// Local conversation history, oldest first.
    messages: Vec<ChatMessage>,
    /// OpenCode session id, created lazily on the first OpenCode chat.
    opencode_session_id: Option<String>,
    /// Value of `info.parentID` from the last OpenCode reply; sent back as
    /// `parentID` with the next message.
    opencode_parent_id: Option<String>,
}
impl AgentClientSession {
/// Creates a client for `provider` rooted at `base_url`.
///
/// No model is selected yet (see `set_model`), the history is empty, and a
/// fresh `Session` is created for HTTP traffic.
pub fn new(
    provider: LlmProvider,
    base_url: impl Into<String>,
    api_key: Option<String>,
) -> Self {
    let base_url: String = base_url.into();
    let session = Session::new();
    Self {
        provider,
        base_url,
        api_key,
        model: None,
        session,
        messages: vec![],
        opencode_session_id: None,
        opencode_parent_id: None,
    }
}
/// Appends a system message to the history.
///
/// Earlier system messages are kept, not replaced; calling this twice
/// leaves two system messages.
pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
    let instruction = ChatMessage::new(MessageRole::System, prompt);
    self.messages.push(instruction);
}
/// Selects the model used for subsequent requests.
///
/// For every provider except Anthropic the id must appear in
/// `list_models()`. Anthropic is not validated because `list_models`
/// returns an empty list for it.
///
/// # Errors
///
/// Fails when listing models fails or when the id is not among them; the
/// previously selected model is kept in that case.
pub async fn set_model(&mut self, model: impl Into<String>) -> anyhow::Result<()> {
    let model: String = model.into();
    let must_validate = self.provider != LlmProvider::Anthropic;
    if must_validate {
        let catalog = self.list_models().await?;
        let known = catalog.iter().any(|entry| entry.id == model);
        if !known {
            return Err(anyhow::anyhow!(
                "Model '{model}' is invalid. Use list_models() to get valid models"
            ));
        }
    }
    self.model = Some(model);
    Ok(())
}
/// Currently selected model id, if any.
pub fn model(&self) -> Option<&str> {
    self.model.as_ref().map(String::as_str)
}
/// Lists the models offered by the configured provider.
///
/// Anthropic is not queried and yields an empty list.
///
/// # Errors
///
/// Propagates transport errors, non-200 statuses and JSON parse errors
/// from the provider-specific listing call.
pub async fn list_models(&mut self) -> anyhow::Result<Vec<ModelInfo>> {
    let base_url = self.base_url.as_str();
    let session = &mut self.session;
    match self.provider {
        LlmProvider::Anthropic => Ok(Vec::new()),
        LlmProvider::Ollama => Self::list_models_ollama(base_url, session).await,
        LlmProvider::OpenAI => {
            Self::list_models_openai(base_url, &self.api_key, session).await
        }
        LlmProvider::OpenCode => Self::list_models_opencode(base_url, session).await,
    }
}
/// Fetches `GET {base_url}/config/providers` from an OpenCode server and
/// flattens every provider's `models` object into `ModelInfo`s whose id is
/// `"{provider_id}:{model_id}"`.
///
/// A provider without an `id` is reported as `"unknown"`; a model without a
/// `name` uses its key as the name. A missing `providers` array yields an
/// empty list.
async fn list_models_opencode(
    base_url: &str,
    session: &mut Session,
) -> anyhow::Result<Vec<ModelInfo>> {
    let endpoint = format!("{base_url}/config/providers");
    let mut response = session.get(&endpoint, vec![]).await?;
    let raw = response.body.data().await;
    let text = String::from_utf8_lossy(raw).to_string();
    if response.http_code != 200 {
        return Err(anyhow::anyhow!(
            "Failed to list OpenCode models: HTTP {}",
            response.http_code
        ));
    }
    let payload: serde_json::Value = serde_json::from_str(&text)?;
    let mut found: Vec<ModelInfo> = Vec::new();
    for entry in payload["providers"].as_array().into_iter().flatten() {
        let provider_id = entry["id"].as_str().unwrap_or("unknown");
        if let Some(catalog) = entry["models"].as_object() {
            found.extend(catalog.iter().map(|(model_id, details)| ModelInfo {
                id: format!("{provider_id}:{model_id}"),
                name: details["name"].as_str().unwrap_or(model_id).to_string(),
                provider_id: provider_id.to_string(),
            }));
        }
    }
    Ok(found)
}
/// Fetches `GET {base_url}/v1/models` from an OpenAI-compatible server.
///
/// Sends a bearer token when `api_key` is set. Every entry of `data` with
/// a non-empty `id` becomes a `ModelInfo` whose id and name are that id and
/// whose provider is `"openai"`.
async fn list_models_openai(
    base_url: &str,
    api_key: &Option<String>,
    session: &mut Session,
) -> anyhow::Result<Vec<ModelInfo>> {
    let endpoint = format!("{base_url}/v1/models");
    let mut header_args = vec![crate::Headers::Custom((
        "Content-Type".to_string(),
        "application/json".to_string(),
    ))];
    if let Some(key) = api_key {
        header_args.push(crate::Headers::Custom((
            "Authorization".to_string(),
            format!("Bearer {key}"),
        )));
    }
    let mut response = session.get(&endpoint, header_args).await?;
    let raw = response.body.data().await;
    let text = String::from_utf8_lossy(raw).to_string();
    if response.http_code != 200 {
        return Err(anyhow::anyhow!(
            "Failed to list OpenAI models: HTTP {}",
            response.http_code
        ));
    }
    let payload: serde_json::Value = serde_json::from_str(&text)?;
    let listed: Vec<ModelInfo> = payload["data"]
        .as_array()
        .map(|entries| {
            entries
                .iter()
                .filter_map(|entry| entry["id"].as_str())
                .filter(|id| !id.is_empty())
                .map(|id| ModelInfo {
                    id: id.to_string(),
                    name: id.to_string(),
                    provider_id: "openai".to_string(),
                })
                .collect()
        })
        .unwrap_or_default();
    Ok(listed)
}
/// Fetches `GET {base_url}/api/tags` from an Ollama server.
///
/// Every entry of `models` with a non-empty `name` becomes a `ModelInfo`
/// whose id and name are that name and whose provider is `"ollama"`.
async fn list_models_ollama(
    base_url: &str,
    session: &mut Session,
) -> anyhow::Result<Vec<ModelInfo>> {
    let endpoint = format!("{base_url}/api/tags");
    let mut response = session.get(&endpoint, vec![]).await?;
    let raw = response.body.data().await;
    let text = String::from_utf8_lossy(raw).to_string();
    if response.http_code != 200 {
        return Err(anyhow::anyhow!(
            "Failed to list Ollama models: HTTP {}",
            response.http_code
        ));
    }
    let payload: serde_json::Value = serde_json::from_str(&text)?;
    let installed: Vec<ModelInfo> = payload["models"]
        .as_array()
        .map(|entries| {
            entries
                .iter()
                .filter_map(|entry| entry["name"].as_str())
                .filter(|name| !name.is_empty())
                .map(|name| ModelInfo {
                    id: name.to_string(),
                    name: name.to_string(),
                    provider_id: "ollama".to_string(),
                })
                .collect()
        })
        .unwrap_or_default();
    Ok(installed)
}
/// Returns the selected model id, or an error telling the caller to call
/// `set_model()` first.
fn ensure_model(&self) -> anyhow::Result<&str> {
    match self.model.as_deref() {
        Some(name) => Ok(name),
        None => Err(anyhow::anyhow!("Model not set. Call set_model() first.")),
    }
}
/// The local conversation history, oldest message first.
pub fn messages(&self) -> &[ChatMessage] {
    self.messages.as_slice()
}
pub fn clear_messages(&mut self) {
let system_msgs: Vec<ChatMessage> = self
.messages
.drain(..)
.filter(|m| m.role == MessageRole::System)
.collect();
self.messages = system_msgs;
}
/// Sends `message` as a user turn and returns the assistant's reply.
///
/// On success, both the user message and the reply are appended to the
/// history. On any error, the user message is removed again. A caller that
/// retries therefore does not accumulate duplicate, unanswered user turns.
///
/// # Errors
///
/// Returns an error in these cases:
/// - transport errors;
/// - a non-200 HTTP status (the response body is included in the message);
/// - a response body that cannot be parsed;
/// - any error reported by `build_request` or `chat_opencode`.
pub async fn chat(&mut self, message: impl Into<String>) -> anyhow::Result<String> {
    self.messages
        .push(ChatMessage::new(MessageRole::User, message));
    let outcome = if self.provider == LlmProvider::OpenCode {
        self.chat_opencode(false).await
    } else {
        self.chat_via_http().await
    };
    if outcome.is_err() {
        // Every fallible step runs before the assistant reply is appended,
        // so the last entry is still the user message pushed above.
        self.messages.pop();
    }
    outcome
}

/// Sends the whole history to the OpenAI/Anthropic/Ollama endpoint in one
/// non-streaming request and appends the parsed reply to the history.
async fn chat_via_http(&mut self) -> anyhow::Result<String> {
    let (url, body, headers) = self.build_request(false)?;
    let args: Vec<_> = headers.into_iter().map(crate::Headers::Custom).collect();
    let mut res = self.session.post_json(&url, body, args).await?;
    let body_data = res.body.data().await;
    let response_text = String::from_utf8_lossy(body_data).to_string();
    if res.http_code != 200 {
        return Err(anyhow::anyhow!(
            "HTTP error {}: {}",
            res.http_code,
            response_text
        ));
    }
    let content = self.parse_response(&response_text)?;
    self.messages
        .push(ChatMessage::new(MessageRole::Assistant, content.clone()));
    Ok(content)
}
/// Mutable access to the underlying HTTP session, e.g. for custom
/// configuration by the caller.
pub fn session_mut(&mut self) -> &mut Session {
    let http = &mut self.session;
    http
}
async fn chat_opencode(&mut self, _stream: bool) -> anyhow::Result<String> {
if self.opencode_session_id.is_none() {
let create_url = format!("{}/session", self.base_url);
let create_body = serde_json::json!({"title": "potato-agent-session"});
let mut create_res = self
.session
.post_json(&create_url, create_body, vec![])
.await?;
let create_data = create_res.body.data().await;
let create_text = String::from_utf8_lossy(create_data).to_string();
if create_res.http_code != 200 {
return Err(anyhow::anyhow!(
"OpenCode create session failed {}: {}",
create_res.http_code,
create_text
));
}
let create_json: serde_json::Value = serde_json::from_str(&create_text)?;
self.opencode_session_id = Some(
create_json["id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("OpenCode session response missing id"))?
.to_string(),
);
}
let session_id = self.opencode_session_id.as_ref().unwrap();
let url = format!("{}/session/{}/message", self.base_url, session_id);
let last_msg = self
.messages
.last()
.ok_or_else(|| anyhow::anyhow!("No message to send"))?;
let parts = serde_json::json!([{"type": "text", "text": last_msg.content}]);
let (provider_id, model_id) = self.parse_opencode_model()?;
let mut body = serde_json::json!({
"parts": parts,
"model": {
"providerID": provider_id,
"modelID": model_id,
},
});
if let Some(ref parent_id) = self.opencode_parent_id {
body["parentID"] = serde_json::Value::String(parent_id.clone());
}
let mut response_text = String::new();
for attempt in 0..3 {
let mut res = self.session.post_json(&url, body.clone(), vec![]).await?;
let body_data = res.body.data().await;
response_text = String::from_utf8_lossy(body_data).to_string();
if res.http_code != 200 {
return Err(anyhow::anyhow!(
"OpenCode message failed {}: {}",
res.http_code,
response_text
));
}
if !response_text.trim().is_empty() {
break;
}
if attempt < 2 {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
self.session.force_reconnect();
}
}
let content = self.parse_opencode_response(&response_text)?;
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&response_text) {
if let Some(parent_id) = json["info"]["parentID"].as_str() {
self.opencode_parent_id = Some(parent_id.to_string());
}
}
self.messages
.push(ChatMessage::new(MessageRole::Assistant, content.clone()));
Ok(content)
}
/// Splits the selected model into OpenCode's `(providerID, modelID)` pair.
///
/// Only the first `':'` separates the two parts, so the model part may
/// itself contain colons. An id without any `':'` is attributed to the
/// provider `"opencode"`.
///
/// # Errors
///
/// Fails when no model has been set.
fn parse_opencode_model(&self) -> anyhow::Result<(String, String)> {
    let selected = self.ensure_model()?;
    let (provider_id, model_id) = selected
        .split_once(':')
        .unwrap_or(("opencode", selected));
    Ok((provider_id.to_owned(), model_id.to_owned()))
}
/// Extracts the reply text from an OpenCode message response.
///
/// Concatenates the `text` of every part whose `type` is `"text"`. Parts
/// without a `type` are also kept, for compatibility. Other part kinds may
/// carry a `text` field as well. Presumably these include OpenCode's
/// `reasoning` parts (verify against the message schema). They are excluded
/// so they do not leak into the answer, mirroring how Anthropic content
/// blocks are filtered in `parse_response`.
///
/// # Errors
///
/// Fails on an empty/whitespace body, invalid JSON, or a missing `parts`
/// array.
fn parse_opencode_response(&self, text: &str) -> anyhow::Result<String> {
    if text.trim().is_empty() {
        return Err(anyhow::anyhow!("OpenCode response is empty"));
    }
    let json: serde_json::Value = serde_json::from_str(text)?;
    let parts = json["parts"]
        .as_array()
        .ok_or_else(|| anyhow::anyhow!("OpenCode response missing parts"))?;
    let result: String = parts
        .iter()
        .filter(|part| matches!(part["type"].as_str(), None | Some("text")))
        .filter_map(|part| part["text"].as_str())
        .collect();
    Ok(result)
}
pub async fn chat_stream(
&mut self,
message: impl Into<String>,
) -> anyhow::Result<tokio::sync::mpsc::Receiver<StreamChunk>> {
let user_msg = ChatMessage::new(MessageRole::User, message);
self.messages.push(user_msg);
if self.provider == LlmProvider::OpenCode {
let content = self.chat_opencode(false).await?;
let (tx, rx) = tokio::sync::mpsc::channel::<StreamChunk>(64);
tokio::spawn(async move {
for line in content.lines() {
if tx
.send(StreamChunk::Content(line.to_string()))
.await
.is_err()
{
return;
}
}
let _ = tx.send(StreamChunk::Done).await;
});
return Ok(rx);
}
let (url, body, headers) = self.build_request(true)?;
let mut args = Vec::new();
for (k, v) in headers {
args.push(crate::Headers::Custom((k, v)));
}
let mut res = self.session.post_json(&url, body, args).await?;
if res.http_code != 200 {
let body_data = res.body.data().await;
return Err(anyhow::anyhow!(
"HTTP error {}: {}",
res.http_code,
String::from_utf8_lossy(body_data)
));
}
let (tx, rx) = tokio::sync::mpsc::channel::<StreamChunk>(64);
let provider = self.provider.clone();
tokio::spawn(async move {
let mut stream = res.body.stream_data();
let mut buffer = String::new();
while let Some(chunk) = stream.next().await {
let text = String::from_utf8_lossy(&chunk);
buffer.push_str(&text);
match provider {
LlmProvider::OpenAI => {
while let Some(pos) = buffer.find("\n\n") {
let event = buffer[..pos].to_string();
buffer = buffer[pos + 2..].to_string();
if let Some(content) = Self::parse_openai_sse_chunk(&event) {
if content.is_empty() {
continue;
}
if tx.send(StreamChunk::Content(content)).await.is_err() {
return;
}
}
}
}
LlmProvider::Anthropic => {
while let Some(pos) = buffer.find("\n\n") {
let event = buffer[..pos].to_string();
buffer = buffer[pos + 2..].to_string();
if let Some(content) = Self::parse_anthropic_sse_chunk(&event) {
if content.is_empty() {
continue;
}
if tx.send(StreamChunk::Content(content)).await.is_err() {
return;
}
}
}
}
LlmProvider::Ollama => {
while let Some(pos) = buffer.find('\n') {
let line = buffer[..pos].to_string();
buffer = buffer[pos + 1..].to_string();
if let Some(content) = Self::parse_ollama_ndjson_chunk(&line) {
if content.is_empty() {
continue;
}
if tx.send(StreamChunk::Content(content)).await.is_err() {
return;
}
}
}
}
LlmProvider::OpenCode => {
}
}
}
let _ = tx.send(StreamChunk::Done).await;
});
Ok(rx)
}
/// Appends an assistant message to the history. Use this, for example, to
/// record a reply collected from `chat_stream`.
pub fn append_assistant_message(&mut self, content: impl Into<String>) {
    let reply = ChatMessage::new(MessageRole::Assistant, content);
    self.messages.push(reply);
}
/// Serializes the client's configuration and history to a JSON string.
///
/// The object has these keys:
/// - `provider`, `base_url`, `api_key`, `model`;
/// - `messages`;
/// - `opencode_session_id`, `opencode_parent_id`.
///
/// The HTTP session is not included. The API key is written in plain text,
/// so treat the output as a secret.
pub fn serialize(&self) -> anyhow::Result<String> {
    let snapshot: serde_json::Value = serde_json::json!({
        "provider": self.provider,
        "base_url": self.base_url,
        "api_key": self.api_key,
        "model": self.model,
        "messages": self.messages,
        "opencode_session_id": self.opencode_session_id,
        "opencode_parent_id": self.opencode_parent_id,
    });
    let encoded = snapshot.to_string();
    Ok(encoded)
}
/// Restores a client from JSON produced by `serialize`.
///
/// `provider`, `base_url` (a string) and `messages` are required. The other
/// fields are optional strings; absent, `null` or non-string values become
/// `None`. A fresh HTTP `Session` is created.
///
/// # Errors
///
/// Fails on invalid JSON, a missing required field, or a `provider` /
/// `messages` value that does not deserialize.
pub fn deserialize(json: &str) -> anyhow::Result<Self> {
    /// Looks up a field that must be present.
    fn required<'a>(
        state: &'a serde_json::Value,
        key: &str,
    ) -> anyhow::Result<&'a serde_json::Value> {
        state
            .get(key)
            .ok_or_else(|| anyhow::anyhow!("missing {key} field"))
    }

    /// Reads an optional string field.
    fn optional_string(state: &serde_json::Value, key: &str) -> Option<String> {
        state
            .get(key)
            .and_then(serde_json::Value::as_str)
            .map(str::to_owned)
    }

    let state: serde_json::Value = serde_json::from_str(json)?;
    let provider: LlmProvider = serde_json::from_value(required(&state, "provider")?.clone())?;
    let base_url = optional_string(&state, "base_url")
        .ok_or_else(|| anyhow::anyhow!("missing base_url field"))?;
    let api_key = optional_string(&state, "api_key");
    let model = optional_string(&state, "model");
    let messages: Vec<ChatMessage> =
        serde_json::from_value(required(&state, "messages")?.clone())?;
    let opencode_session_id = optional_string(&state, "opencode_session_id");
    let opencode_parent_id = optional_string(&state, "opencode_parent_id");
    Ok(Self {
        provider,
        base_url,
        api_key,
        model,
        session: Session::new(),
        messages,
        opencode_session_id,
        opencode_parent_id,
    })
}
fn build_request(
&self,
stream: bool,
) -> anyhow::Result<(String, serde_json::Value, Vec<(String, String)>)> {
let mut headers = vec![("Content-Type".to_string(), "application/json".to_string())];
if let Some(ref key) = self.api_key {
headers.push(("Authorization".to_string(), format!("Bearer {key}")));
}
match self.provider {
LlmProvider::OpenAI => {
let url = format!("{}/v1/chat/completions", self.base_url);
let messages: Vec<serde_json::Value> = self
.messages
.iter()
.map(|m| {
serde_json::json!({
"role": m.role.as_str(),
"content": m.content,
})
})
.collect();
let body = serde_json::json!({
"model": self.model,
"messages": messages,
"stream": stream,
});
Ok((url, body, headers))
}
LlmProvider::Anthropic => {
let url = format!("{}/v1/messages", self.base_url);
let system_msg = self
.messages
.iter()
.find(|m| m.role == MessageRole::System)
.map(|m| m.content.clone());
let messages: Vec<serde_json::Value> = self
.messages
.iter()
.filter(|m| m.role != MessageRole::System)
.map(|m| {
serde_json::json!({
"role": m.role.as_str(),
"content": m.content,
})
})
.collect();
let mut body = serde_json::json!({
"model": self.model,
"messages": messages,
"max_tokens": 4096,
"stream": stream,
});
if let Some(system) = system_msg {
body["system"] = serde_json::Value::String(system);
}
headers.push((
"x-api-key".to_string(),
self.api_key.clone().unwrap_or_default(),
));
headers.push(("anthropic-version".to_string(), "2023-06-01".to_string()));
Ok((url, body, headers))
}
LlmProvider::Ollama => {
let url = format!("{}/api/chat", self.base_url);
let messages: Vec<serde_json::Value> = self
.messages
.iter()
.map(|m| {
serde_json::json!({
"role": m.role.as_str(),
"content": m.content,
})
})
.collect();
let body = serde_json::json!({
"model": self.model,
"messages": messages,
"stream": stream,
});
Ok((url, body, headers))
}
LlmProvider::OpenCode => {
let url = format!("{}/session/message", self.base_url);
let body = serde_json::json!({});
Ok((url, body, headers))
}
}
}
/// Extracts the assistant text from a non-streaming response body.
///
/// OpenCode is delegated to `parse_opencode_response`. The other providers
/// read their reply as follows; a missing field yields an empty string
/// rather than an error:
/// - OpenAI: `choices[0].message.content`;
/// - Ollama: `message.content`;
/// - Anthropic: concatenates the `text` of every `content` block whose
///   `type` is `"text"`.
///
/// # Errors
///
/// Fails when the body is not valid JSON, or as reported by
/// `parse_opencode_response`.
fn parse_response(&self, text: &str) -> anyhow::Result<String> {
    if self.provider == LlmProvider::OpenCode {
        return self.parse_opencode_response(text);
    }
    let json: serde_json::Value = serde_json::from_str(text)?;
    let reply: String = match self.provider {
        LlmProvider::OpenAI => json["choices"][0]["message"]["content"]
            .as_str()
            .unwrap_or_default()
            .to_owned(),
        LlmProvider::Ollama => json["message"]["content"]
            .as_str()
            .unwrap_or_default()
            .to_owned(),
        LlmProvider::Anthropic => json["content"]
            .as_array()
            .into_iter()
            .flatten()
            .filter(|block| block["type"].as_str() == Some("text"))
            .filter_map(|block| block["text"].as_str())
            .collect(),
        LlmProvider::OpenCode => unreachable!("OpenCode responses are handled above"),
    };
    Ok(reply)
}
/// Extracts the text delta from one OpenAI server-sent event.
///
/// Scans the event's `data:` lines. The SSE format makes the single space
/// after the colon optional; previously only `"data: "` was accepted.
///
/// # Returns
///
/// - `Some("")` for the terminal `[DONE]` sentinel;
/// - `Some(text)` for the first payload carrying
///   `choices[0].delta.content`;
/// - `None` when no line carries either.
fn parse_openai_sse_chunk(event: &str) -> Option<String> {
    event
        .lines()
        .filter_map(|line| line.strip_prefix("data:"))
        .map(|rest| rest.strip_prefix(' ').unwrap_or(rest))
        .find_map(|data| {
            if data == "[DONE]" {
                return Some(String::new());
            }
            let json: serde_json::Value = serde_json::from_str(data).ok()?;
            json["choices"][0]["delta"]["content"]
                .as_str()
                .map(str::to_string)
        })
}
/// Extracts the text delta from one Anthropic server-sent event.
///
/// Scans the event's `data:` lines. The SSE format makes the single space
/// after the colon optional; previously only `"data: "` was accepted.
///
/// Returns the `delta.text` of the first JSON payload that has one, or
/// `None` when no line carries text.
fn parse_anthropic_sse_chunk(event: &str) -> Option<String> {
    event
        .lines()
        .filter_map(|line| line.strip_prefix("data:"))
        .map(|rest| rest.strip_prefix(' ').unwrap_or(rest))
        .filter_map(|data| serde_json::from_str::<serde_json::Value>(data).ok())
        .find_map(|json| json["delta"]["text"].as_str().map(str::to_string))
}
/// Extracts the text delta from one line of Ollama's NDJSON stream.
///
/// # Returns
///
/// - `Some("")` once the line reports `"done": true`;
/// - otherwise `message.content` when present;
/// - `None` for lines that are not valid JSON or carry no content.
fn parse_ollama_ndjson_chunk(line: &str) -> Option<String> {
    let json = serde_json::from_str::<serde_json::Value>(line).ok()?;
    if json["done"].as_bool() == Some(true) {
        return Some(String::new());
    }
    json["message"]["content"].as_str().map(str::to_string)
}
}