use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use newt_core::SessionId;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::sync::Mutex;
/// Per-session state created by the `new_session` method and looked up by
/// later `set_session_model` / `prompt` calls.
#[derive(Debug, Clone)]
pub struct Session {
    /// Root of the session's workspace; patches are applied and diffs are
    /// captured against this path.
    pub workspace_path: PathBuf,
    /// Model name recorded by `set_session_model`; `None` until that is called.
    /// NOTE(review): not read by the prompt handlers visible in this file —
    /// confirm the override is actually honoured somewhere.
    pub model_override: Option<String>,
    /// When `true`, prompts run through `newt_coder::Coder` instead of the
    /// flat single-completion path.
    pub coder_enabled: bool,
}
/// JSON-RPC server that manages workspace sessions and answers prompts with
/// an inference backend.
pub struct AcpServer {
    /// Live sessions keyed by id, guarded by a tokio (async-aware) mutex.
    sessions: Arc<Mutex<HashMap<SessionId, Session>>>,
    /// Backend used for flat completions; also handed to `newt_coder::Coder`.
    backend: Arc<dyn newt_inference::InferenceBackend>,
}
/// Outcome of a `prompt` call; serialized as the JSON-RPC `result` object.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct TaskReply {
    /// Identifier of the model that produced the reply. Required on the wire;
    /// [`TaskReply::new`] rejects an empty value.
    pub model_id: String,
    /// Reply text: the raw completion on the flat path, a short summary on
    /// the coder path.
    pub content: String,
    /// Workspace diff captured after the prompt ran.
    pub diff: String,
    /// Whether `diff` is empty, as judged by `crate::diff::is_empty_diff`.
    pub empty_diff: bool,
    /// Whether the prompt's edits are considered applied to the workspace.
    pub diff_applied: bool,
    /// How the coder path emitted its edits (e.g. `"whole_files"`); omitted
    /// from the wire when `None` and defaulted to `None` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub emission_shape: Option<String>,
    /// Unprocessed model output, when the producer recorded it; omitted from
    /// the wire when `None` and defaulted to `None` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub raw_emission: Option<String>,
}
impl TaskReply {
    /// Builds a reply, deriving `empty_diff` from `diff`.
    ///
    /// `emission_shape` and `raw_emission` start as `None`; attach them with
    /// [`TaskReply::with_emission_shape`] and [`TaskReply::with_raw_emission`].
    ///
    /// # Errors
    ///
    /// Fails when `model_id` is empty or consists only of whitespace — the
    /// model id is mandatory on every reply.
    pub fn new(
        model_id: impl Into<String>,
        content: impl Into<String>,
        diff: impl Into<String>,
        diff_applied: bool,
    ) -> anyhow::Result<Self> {
        let model_id = model_id.into();
        // A whitespace-only id identifies no model any better than "" does.
        if model_id.trim().is_empty() {
            anyhow::bail!("TaskReply.model_id is mandatory and must not be empty");
        }
        let diff = diff.into();
        let empty_diff = crate::diff::is_empty_diff(&diff);
        Ok(Self {
            model_id,
            content: content.into(),
            diff,
            empty_diff,
            diff_applied,
            emission_shape: None,
            raw_emission: None,
        })
    }

    /// Returns the reply with `emission_shape` set to `shape`.
    #[must_use]
    pub fn with_emission_shape(mut self, shape: impl Into<String>) -> Self {
        self.emission_shape = Some(shape.into());
        self
    }

    /// Returns the reply with `raw_emission` set to `raw`.
    #[must_use]
    pub fn with_raw_emission(mut self, raw: impl Into<String>) -> Self {
        self.raw_emission = Some(raw.into());
        self
    }
}
impl AcpServer {
    /// Creates a server with no sessions that answers prompts using `backend`.
    pub fn new(backend: Arc<dyn newt_inference::InferenceBackend>) -> Self {
        Self {
            sessions: Arc::new(Mutex::new(HashMap::new())),
            backend,
        }
    }

    /// Serves requests read from stdin, writing responses to stdout.
    ///
    /// # Errors
    ///
    /// Same as [`AcpServer::run`].
    pub async fn run_stdio(self) -> anyhow::Result<()> {
        self.run(tokio::io::stdin(), tokio::io::stdout()).await
    }

    /// Serves newline-delimited JSON-RPC 2.0 requests from `reader` until EOF.
    ///
    /// Blank lines are skipped; every other line gets exactly one response
    /// line on `writer`. Protocol problems are reported as JSON-RPC errors
    /// instead of stopping the loop:
    ///
    /// * `-32700` — the line is not valid JSON (response `id` is `null`);
    /// * `-32600` — the JSON is not an object with a string `method`
    ///   (batch arrays are not supported and land here too);
    /// * `-32601` — the method is not implemented by this server;
    /// * `-32603` — the handler of a known method returned an error.
    ///
    /// # Errors
    ///
    /// Returns an error only when reading `reader`, serializing a response,
    /// or writing/flushing `writer` fails; that ends the loop.
    pub async fn run<R, W>(self, reader: R, mut writer: W) -> anyhow::Result<()>
    where
        R: tokio::io::AsyncRead + Unpin,
        W: tokio::io::AsyncWrite + Unpin,
    {
        // Standard JSON-RPC 2.0 error codes.
        const PARSE_ERROR: i32 = -32700;
        const INVALID_REQUEST: i32 = -32600;
        const METHOD_NOT_FOUND: i32 = -32601;
        const INTERNAL_ERROR: i32 = -32603;

        let mut lines = BufReader::new(reader).lines();
        while let Some(line) = lines.next_line().await? {
            if line.trim().is_empty() {
                continue;
            }
            let request: Value = match serde_json::from_str(&line) {
                Ok(v) => v,
                Err(e) => {
                    let resp =
                        error_response(Value::Null, PARSE_ERROR, &format!("Parse error: {e}"));
                    write_response(&mut writer, &resp).await?;
                    continue;
                }
            };
            let id = request.get("id").cloned().unwrap_or(Value::Null);
            let response = match request.get("method").and_then(Value::as_str) {
                // A non-object, or an object without a string `method`, is a
                // malformed request rather than a call to a method named "".
                None => error_response(
                    id,
                    INVALID_REQUEST,
                    "Invalid Request: expected an object with a string \"method\"",
                ),
                Some(method) => {
                    let params = request.get("params").cloned().unwrap_or(Value::Null);
                    match self.handle(method, params).await {
                        Some(Ok(result)) => serde_json::json!({
                            "jsonrpc": "2.0",
                            "id": id,
                            "result": result,
                        }),
                        // `{:#}` renders the whole anyhow context chain so the
                        // client sees the root cause, not just the outer message.
                        Some(Err(e)) => error_response(id, INTERNAL_ERROR, &format!("{e:#}")),
                        None => error_response(
                            id,
                            METHOD_NOT_FOUND,
                            &format!("method not found: {method}"),
                        ),
                    }
                }
            };
            write_response(&mut writer, &response).await?;
        }
        Ok(())
    }

    /// Routes `method` to its handler.
    ///
    /// Returns `None` when `method` is not implemented, so [`AcpServer::run`]
    /// can answer with JSON-RPC's "Method not found" code instead of a
    /// generic internal error.
    async fn handle(&self, method: &str, params: Value) -> Option<anyhow::Result<Value>> {
        let result = match method {
            "initialize" => self.handle_initialize(params).await,
            "new_session" => self.handle_new_session(params).await,
            "set_session_model" => self.handle_set_session_model(params).await,
            "prompt" => self.handle_prompt(params).await,
            _ => return None,
        };
        Some(result)
    }

    /// Answers `initialize` with static server metadata and capabilities.
    async fn handle_initialize(&self, _params: Value) -> anyhow::Result<Value> {
        Ok(serde_json::json!({
            "protocolVersion": "v0.1",
            "serverInfo": {
                "name": "newt-acp-worker",
                "version": env!("CARGO_PKG_VERSION"),
            },
            "capabilities": {
                "prompting": true,
                "diff_capture": true,
            },
        }))
    }

    /// Registers a new session rooted at `params.workspace_path`.
    ///
    /// The coder path is enabled when the `NEWT_CODER` environment variable is
    /// exactly `"1"` or `params.coder` is `true`. Returns the new session id
    /// and whether the coder path is on.
    ///
    /// # Errors
    ///
    /// Fails if `workspace_path` is missing, does not exist, or is not a
    /// directory.
    async fn handle_new_session(&self, params: Value) -> anyhow::Result<Value> {
        let workspace_path: PathBuf = params
            .get("workspace_path")
            .and_then(|p| p.as_str())
            .map(PathBuf::from)
            .ok_or_else(|| anyhow::anyhow!("workspace_path required"))?;
        if !workspace_path.exists() {
            anyhow::bail!(
                "workspace_path does not exist: {}",
                workspace_path.display()
            );
        }
        // Prompts apply patches and capture diffs under this path, which
        // cannot meaningfully work on a plain file; reject it up front.
        if !workspace_path.is_dir() {
            anyhow::bail!(
                "workspace_path is not a directory: {}",
                workspace_path.display()
            );
        }
        let env_coder = std::env::var("NEWT_CODER")
            .map(|v| v == "1")
            .unwrap_or(false);
        let param_coder = params
            .get("coder")
            .and_then(|v| v.as_bool())
            .unwrap_or(false);
        let coder_enabled = env_coder || param_coder;
        let session_id = SessionId::new();
        self.sessions.lock().await.insert(
            session_id,
            Session {
                workspace_path,
                model_override: None,
                coder_enabled,
            },
        );
        Ok(serde_json::json!({
            "session_id": session_id.to_string(),
            "coder": coder_enabled,
        }))
    }

    /// Records a model override on an existing session.
    ///
    /// # Errors
    ///
    /// Fails if `session_id` or `model` is missing, `session_id` does not
    /// parse, or no such session exists.
    async fn handle_set_session_model(&self, params: Value) -> anyhow::Result<Value> {
        let session_id: SessionId = params
            .get("session_id")
            .and_then(|s| s.as_str())
            .ok_or_else(|| anyhow::anyhow!("session_id required"))?
            .parse()?;
        let model = params
            .get("model")
            .and_then(|m| m.as_str())
            .ok_or_else(|| anyhow::anyhow!("model required"))?
            .to_string();
        let mut sessions = self.sessions.lock().await;
        let session = sessions
            .get_mut(&session_id)
            .ok_or_else(|| anyhow::anyhow!("unknown session: {session_id}"))?;
        // NOTE(review): `model_override` is stored here but neither
        // `handle_prompt_flat` nor `handle_prompt_coder` reads it, so this call
        // does not currently change which model answers — confirm whether the
        // backend request should carry the override.
        session.model_override = Some(model);
        Ok(serde_json::json!({ "ok": true }))
    }

    /// Runs one prompt in an existing session and returns its [`TaskReply`]
    /// serialized as JSON.
    ///
    /// Sessions with `coder_enabled` go through `newt_coder`; all others use
    /// the flat single-completion path.
    ///
    /// # Errors
    ///
    /// Fails on a missing or unparsable `session_id`, a missing `prompt`, an
    /// unknown session, or any error from the selected prompt path.
    async fn handle_prompt(&self, params: Value) -> anyhow::Result<Value> {
        let session_id: SessionId = params
            .get("session_id")
            .and_then(|s| s.as_str())
            .ok_or_else(|| anyhow::anyhow!("session_id required"))?
            .parse()?;
        let prompt = params
            .get("prompt")
            .and_then(|p| p.as_str())
            .ok_or_else(|| anyhow::anyhow!("prompt required"))?
            .to_string();
        // Clone the session so the lock is released before awaiting inference.
        let session = {
            let sessions = self.sessions.lock().await;
            sessions
                .get(&session_id)
                .cloned()
                .ok_or_else(|| anyhow::anyhow!("unknown session: {session_id}"))?
        };
        let task_reply = if session.coder_enabled {
            self.handle_prompt_coder(&session, &prompt).await?
        } else {
            self.handle_prompt_flat(&session, &prompt).await?
        };
        Ok(serde_json::to_value(task_reply)?)
    }

    /// Flat path: asks the backend for a unified diff, applies it to the
    /// workspace when the reply looks like one, then captures the workspace
    /// diff.
    ///
    /// A patch that fails to apply is logged and reported as
    /// `diff_applied: false` rather than failing the request. The full
    /// completion text is kept as both `content` and `raw_emission`.
    ///
    /// # Errors
    ///
    /// Fails if the backend call or diff capture fails, or the backend
    /// returns an unusable model id.
    async fn handle_prompt_flat(
        &self,
        session: &Session,
        prompt: &str,
    ) -> anyhow::Result<TaskReply> {
        let req = newt_inference::ChatRequest::new()
            .system("You are a coding assistant. Respond with unified diffs only.")
            .user(prompt.to_string());
        let reply = self.backend.complete(req).await?;
        let diff_applied = if looks_like_unified_diff(&reply.content) {
            match newt_tools::apply_patch(&session.workspace_path, &reply.content) {
                Ok(()) => true,
                Err(e) => {
                    tracing::warn!(error = %e, "patch application failed");
                    false
                }
            }
        } else {
            false
        };
        let diff = crate::diff::capture_diff(&session.workspace_path)?;
        let raw_emission = reply.content.clone();
        TaskReply::new(reply.model_id, reply.content, diff, diff_applied)
            .map(|r| r.with_raw_emission(raw_emission))
            .map_err(|e| anyhow::anyhow!("backend returned malformed reply: {e}"))
    }

    /// Coder path: lets `newt_coder::Coder` edit the workspace, then reports
    /// the captured workspace diff with a one-line summary of files written.
    ///
    /// # Errors
    ///
    /// Fails if the coder run or diff capture fails, or the run reports an
    /// unusable model id.
    async fn handle_prompt_coder(
        &self,
        session: &Session,
        prompt: &str,
    ) -> anyhow::Result<TaskReply> {
        let coder = newt_coder::Coder::new(Arc::clone(&self.backend));
        let caveats = newt_core::Caveats::top();
        let run = coder
            .run(&session.workspace_path, prompt, &caveats)
            .await
            .map_err(|e| anyhow::anyhow!("newt-coder run failed: {e}"))?;
        let diff = crate::diff::capture_diff(&session.workspace_path)?;
        // Counted as applied when the coder reports written files or the
        // captured workspace diff is non-blank.
        let diff_applied = !run.files_written.is_empty() || !diff.trim().is_empty();
        let content = format!(
            "[newt-coder] {} file(s) written via {}",
            run.files_written.len(),
            run.emission_shape,
        );
        Ok(TaskReply::new(run.model_id, content, diff, diff_applied)
            .map_err(|e| anyhow::anyhow!("newt-coder returned malformed reply: {e}"))?
            .with_emission_shape(run.emission_shape)
            .with_raw_emission(run.first_emission))
    }
}
/// Cheap heuristic: does `content` carry both unified-diff file headers?
///
/// Returns `true` only when an old-file marker (`"--- "`) and a new-file
/// marker (`"+++ "`) each occur somewhere in the text; their positions and
/// order are not checked.
fn looks_like_unified_diff(content: &str) -> bool {
    const HEADER_MARKERS: [&str; 2] = ["--- ", "+++ "];
    HEADER_MARKERS
        .iter()
        .all(|&marker| content.contains(marker))
}
/// Sends `response` as one line of compact JSON and flushes `writer`.
///
/// Flushing after every message ensures the peer is not left waiting on a
/// response that is still sitting in a buffer.
///
/// # Errors
///
/// Propagates JSON serialization failures and any I/O error from the write
/// or the flush.
async fn write_response<W: tokio::io::AsyncWrite + Unpin>(
    writer: &mut W,
    response: &Value,
) -> anyhow::Result<()> {
    let mut frame = serde_json::to_vec(response)?;
    frame.push(b'\n');
    writer.write_all(&frame).await?;
    writer.flush().await?;
    Ok(())
}
/// Builds a JSON-RPC 2.0 error response.
///
/// `id` echoes the request's id (`null` when it is unknown), `code` is the
/// JSON-RPC error code, and `message` is a human-readable description.
fn error_response(id: Value, code: i32, message: &str) -> Value {
    let error_object = serde_json::json!({ "code": code, "message": message });
    serde_json::json!({
        "jsonrpc": "2.0",
        "id": id,
        "error": error_object,
    })
}
/// Unit tests pinning the `TaskReply` wire contract and the unified-diff
/// detection heuristic.
#[cfg(test)]
mod tests {
    use super::*;

    // --- TaskReply::new: model_id validation and field wiring ---
    #[test]
    fn task_reply_rejects_empty_model_id() {
        let err = TaskReply::new("", "content", "", false).unwrap_err();
        assert!(
            err.to_string().contains("mandatory"),
            "expected mandatory-id error, got: {err}"
        );
    }
    #[test]
    fn task_reply_accepts_nonempty_model_id() {
        let r = TaskReply::new("qwen2.5-coder:32b", "hi", "", false).unwrap();
        assert_eq!(r.model_id, "qwen2.5-coder:32b");
        assert_eq!(r.content, "hi");
    }
    #[test]
    fn task_reply_sets_empty_diff_from_diff_string() {
        let r = TaskReply::new("m", "c", "", false).unwrap();
        assert!(r.empty_diff);
        let r = TaskReply::new("m", "c", "real\nchanges\n", true).unwrap();
        assert!(!r.empty_diff);
    }

    // --- Serde: model_id is required on the wire ---
    #[test]
    fn task_reply_serde_round_trip_preserves_model_id() {
        let r = TaskReply::new("m", "c", "d\n", true).unwrap();
        let json = serde_json::to_string(&r).unwrap();
        assert!(json.contains("\"model_id\":\"m\""));
        let back: TaskReply = serde_json::from_str(&json).unwrap();
        assert_eq!(back, r);
    }
    #[test]
    fn task_reply_deserialize_without_model_id_fails() {
        let bad = r#"{"content":"c","diff":"","empty_diff":true,"diff_applied":false}"#;
        let err = serde_json::from_str::<TaskReply>(bad).unwrap_err();
        assert!(
            err.to_string().contains("model_id"),
            "expected missing-model_id error, got: {err}"
        );
    }

    // --- emission_shape: optional, omitted when None, backward compatible ---
    #[test]
    fn task_reply_emission_shape_defaults_none() {
        let r = TaskReply::new("m", "c", "", false).unwrap();
        assert_eq!(r.emission_shape, None);
    }
    #[test]
    fn task_reply_with_emission_shape_builder() {
        let r = TaskReply::new("m", "c", "", false)
            .unwrap()
            .with_emission_shape("whole_files");
        assert_eq!(r.emission_shape.as_deref(), Some("whole_files"));
    }
    #[test]
    fn task_reply_omits_null_emission_shape_from_wire() {
        let r = TaskReply::new("m", "c", "", false).unwrap();
        let json = serde_json::to_string(&r).unwrap();
        assert!(
            !json.contains("emission_shape"),
            "expected emission_shape omitted when None, got: {json}"
        );
    }
    #[test]
    fn task_reply_carries_emission_shape_on_wire_when_set() {
        let r = TaskReply::new("m", "c", "", true)
            .unwrap()
            .with_emission_shape("whole_files");
        let json = serde_json::to_string(&r).unwrap();
        assert!(json.contains("\"emission_shape\":\"whole_files\""));
        let back: TaskReply = serde_json::from_str(&json).unwrap();
        assert_eq!(back.emission_shape.as_deref(), Some("whole_files"));
    }
    // Payloads produced before emission_shape existed must still deserialize.
    #[test]
    fn task_reply_old_wire_without_emission_shape_still_parses() {
        let old =
            r#"{"model_id":"m","content":"c","diff":"","empty_diff":true,"diff_applied":false}"#;
        let r: TaskReply = serde_json::from_str(old).unwrap();
        assert_eq!(r.model_id, "m");
        assert_eq!(r.emission_shape, None);
    }

    // --- looks_like_unified_diff heuristic ---
    #[test]
    fn looks_like_unified_diff_detects_headers() {
        assert!(looks_like_unified_diff(
            "--- a/f\n+++ b/f\n@@ -1,1 +1,1 @@\n-a\n+b\n"
        ));
        assert!(!looks_like_unified_diff("just prose"));
        assert!(!looks_like_unified_diff("--- only the old header"));
    }
}