1use std::collections::HashMap;
16use std::path::PathBuf;
17use std::sync::Arc;
18
19use newt_core::SessionId;
20use serde::{Deserialize, Serialize};
21use serde_json::Value;
22use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
23use tokio::sync::Mutex;
24
/// Per-session state registered by the `new_session` JSON-RPC method.
///
/// Cloned out of the session map by `prompt` so the map lock is not held while a
/// prompt runs.
#[derive(Debug, Clone)]
pub struct Session {
    /// Workspace root the session operates on; `new_session` checks that it exists.
    pub workspace_path: PathBuf,
    /// Model name recorded by `set_session_model`; `None` until that method is called.
    /// NOTE(review): neither prompt path in this file currently reads this field.
    pub model_override: Option<String>,
    /// When `true`, prompts are routed through `newt_coder::Coder` instead of a single
    /// flat chat completion.
    pub coder_enabled: bool,
}
40
/// JSON-RPC server speaking newline-delimited JSON over an async reader/writer pair.
pub struct AcpServer {
    /// All live sessions, keyed by the id handed out by `new_session`; guarded by an
    /// async (tokio) mutex.
    sessions: Arc<Mutex<HashMap<SessionId, Session>>>,
    /// Inference backend used by the flat prompt path and handed to `newt_coder::Coder`.
    backend: Arc<dyn newt_inference::InferenceBackend>,
}
47
/// Result object returned to the client as the `result` of a `prompt` request.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct TaskReply {
    /// Identifier of the model that produced the reply. `TaskReply::new` rejects an
    /// empty value; deserialization fails when the key is absent.
    pub model_id: String,
    /// Human-readable reply text: the raw model output on the flat path, a
    /// `[newt-coder] ...` summary line on the coder path.
    pub content: String,
    /// Workspace diff captured via `crate::diff::capture_diff` after the prompt ran.
    pub diff: String,
    /// Whether `diff` counts as empty according to `crate::diff::is_empty_diff`.
    pub empty_diff: bool,
    /// Whether the prompt changed the workspace (see the two prompt handlers for the
    /// exact rule each path uses).
    pub diff_applied: bool,
    // Optional: absent on the wire when `None`, and defaulted to `None` when a peer
    // omits it (older wire format).
    /// Name of the emission shape reported by newt-coder; set only on the coder path.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub emission_shape: Option<String>,
    // Same optional-on-the-wire treatment as `emission_shape`.
    /// Unprocessed model output (flat path) or the coder's first emission (coder path).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub raw_emission: Option<String>,
}
97
98impl TaskReply {
99 pub fn new(
102 model_id: impl Into<String>,
103 content: impl Into<String>,
104 diff: impl Into<String>,
105 diff_applied: bool,
106 ) -> anyhow::Result<Self> {
107 let model_id = model_id.into();
108 if model_id.is_empty() {
109 anyhow::bail!("TaskReply.model_id is mandatory and must not be empty");
110 }
111 let diff = diff.into();
112 let empty_diff = crate::diff::is_empty_diff(&diff);
113 Ok(Self {
114 model_id,
115 content: content.into(),
116 diff,
117 empty_diff,
118 diff_applied,
119 emission_shape: None,
120 raw_emission: None,
121 })
122 }
123
124 #[must_use]
127 pub fn with_emission_shape(mut self, shape: impl Into<String>) -> Self {
128 self.emission_shape = Some(shape.into());
129 self
130 }
131
132 #[must_use]
135 pub fn with_raw_emission(mut self, raw: impl Into<String>) -> Self {
136 self.raw_emission = Some(raw.into());
137 self
138 }
139}
140
141impl AcpServer {
142 pub fn new(backend: Arc<dyn newt_inference::InferenceBackend>) -> Self {
144 Self {
145 sessions: Arc::new(Mutex::new(HashMap::new())),
146 backend,
147 }
148 }
149
150 pub async fn run_stdio(self) -> anyhow::Result<()> {
152 self.run(tokio::io::stdin(), tokio::io::stdout()).await
153 }
154
155 pub async fn run<R, W>(self, reader: R, mut writer: W) -> anyhow::Result<()>
157 where
158 R: tokio::io::AsyncRead + Unpin,
159 W: tokio::io::AsyncWrite + Unpin,
160 {
161 let buf = BufReader::new(reader);
162 let mut lines = buf.lines();
163
164 while let Some(line) = lines.next_line().await? {
165 if line.trim().is_empty() {
166 continue;
167 }
168
169 let request: Value = match serde_json::from_str(&line) {
170 Ok(v) => v,
171 Err(e) => {
172 let resp = error_response(Value::Null, -32700, &format!("Parse error: {e}"));
173 write_response(&mut writer, &resp).await?;
174 continue;
175 }
176 };
177
178 let id = request.get("id").cloned().unwrap_or(Value::Null);
179 let method = request.get("method").and_then(|m| m.as_str()).unwrap_or("");
180 let params = request.get("params").cloned().unwrap_or(Value::Null);
181
182 let response = match self.handle(method, params).await {
183 Ok(result) => serde_json::json!({
184 "jsonrpc": "2.0",
185 "id": id,
186 "result": result,
187 }),
188 Err(e) => error_response(id, -32603, &e.to_string()),
189 };
190
191 write_response(&mut writer, &response).await?;
192 }
193
194 Ok(())
195 }
196
197 async fn handle(&self, method: &str, params: Value) -> anyhow::Result<Value> {
199 match method {
200 "initialize" => self.handle_initialize(params).await,
201 "new_session" => self.handle_new_session(params).await,
202 "set_session_model" => self.handle_set_session_model(params).await,
203 "prompt" => self.handle_prompt(params).await,
204 _ => anyhow::bail!("method not found: {method}"),
205 }
206 }
207
208 async fn handle_initialize(&self, _params: Value) -> anyhow::Result<Value> {
210 Ok(serde_json::json!({
211 "protocolVersion": "v0.1",
212 "serverInfo": {
213 "name": "newt-acp-worker",
214 "version": env!("CARGO_PKG_VERSION"),
215 },
216 "capabilities": {
217 "prompting": true,
218 "diff_capture": true,
219 },
220 }))
221 }
222
223 async fn handle_new_session(&self, params: Value) -> anyhow::Result<Value> {
231 let workspace_path: PathBuf = params
232 .get("workspace_path")
233 .and_then(|p| p.as_str())
234 .map(PathBuf::from)
235 .ok_or_else(|| anyhow::anyhow!("workspace_path required"))?;
236
237 if !workspace_path.exists() {
238 anyhow::bail!(
239 "workspace_path does not exist: {}",
240 workspace_path.display()
241 );
242 }
243
244 let env_coder = std::env::var("NEWT_CODER")
245 .map(|v| v == "1")
246 .unwrap_or(false);
247 let param_coder = params
248 .get("coder")
249 .and_then(|v| v.as_bool())
250 .unwrap_or(false);
251 let coder_enabled = env_coder || param_coder;
252
253 let session_id = SessionId::new();
254 let mut sessions = self.sessions.lock().await;
255 sessions.insert(
256 session_id,
257 Session {
258 workspace_path,
259 model_override: None,
260 coder_enabled,
261 },
262 );
263
264 Ok(serde_json::json!({
265 "session_id": session_id.to_string(),
266 "coder": coder_enabled,
267 }))
268 }
269
270 async fn handle_set_session_model(&self, params: Value) -> anyhow::Result<Value> {
273 let session_id: SessionId = params
274 .get("session_id")
275 .and_then(|s| s.as_str())
276 .ok_or_else(|| anyhow::anyhow!("session_id required"))?
277 .parse()?;
278 let model = params
279 .get("model")
280 .and_then(|m| m.as_str())
281 .ok_or_else(|| anyhow::anyhow!("model required"))?
282 .to_string();
283
284 let mut sessions = self.sessions.lock().await;
285 let session = sessions
286 .get_mut(&session_id)
287 .ok_or_else(|| anyhow::anyhow!("unknown session: {session_id}"))?;
288 session.model_override = Some(model);
289
290 Ok(serde_json::json!({ "ok": true }))
291 }
292
293 async fn handle_prompt(&self, params: Value) -> anyhow::Result<Value> {
317 let session_id: SessionId = params
318 .get("session_id")
319 .and_then(|s| s.as_str())
320 .ok_or_else(|| anyhow::anyhow!("session_id required"))?
321 .parse()?;
322 let prompt = params
323 .get("prompt")
324 .and_then(|p| p.as_str())
325 .ok_or_else(|| anyhow::anyhow!("prompt required"))?
326 .to_string();
327
328 let session = {
329 let sessions = self.sessions.lock().await;
330 sessions
331 .get(&session_id)
332 .cloned()
333 .ok_or_else(|| anyhow::anyhow!("unknown session: {session_id}"))?
334 };
335
336 let task_reply = if session.coder_enabled {
337 self.handle_prompt_coder(&session, &prompt).await?
338 } else {
339 self.handle_prompt_flat(&session, &prompt).await?
340 };
341
342 Ok(serde_json::to_value(task_reply)?)
343 }
344
345 async fn handle_prompt_flat(
349 &self,
350 session: &Session,
351 prompt: &str,
352 ) -> anyhow::Result<TaskReply> {
353 let req = newt_inference::ChatRequest::new()
354 .system("You are a coding assistant. Respond with unified diffs only.")
355 .user(prompt.to_string());
356
357 let reply = self.backend.complete(req).await?;
358
359 let diff_applied = if looks_like_unified_diff(&reply.content) {
364 match newt_tools::apply_patch(&session.workspace_path, &reply.content) {
365 Ok(()) => true,
366 Err(e) => {
367 tracing::warn!(error = %e, "patch application failed");
368 false
369 }
370 }
371 } else {
372 false
373 };
374
375 let diff = crate::diff::capture_diff(&session.workspace_path)?;
376 let raw_emission = reply.content.clone();
379 TaskReply::new(reply.model_id, reply.content, diff, diff_applied)
380 .map(|r| r.with_raw_emission(raw_emission))
381 .map_err(|e| anyhow::anyhow!("backend returned malformed reply: {e}"))
382 }
383
384 async fn handle_prompt_coder(
387 &self,
388 session: &Session,
389 prompt: &str,
390 ) -> anyhow::Result<TaskReply> {
391 let coder = newt_coder::Coder::new(Arc::clone(&self.backend));
392 let caveats = newt_core::Caveats::top();
399 let run = coder
400 .run(&session.workspace_path, prompt, &caveats)
401 .await
402 .map_err(|e| anyhow::anyhow!("newt-coder run failed: {e}"))?;
403
404 let diff = crate::diff::capture_diff(&session.workspace_path)?;
407 let diff_applied = !run.files_written.is_empty() || !diff.trim().is_empty();
408
409 let content = format!(
410 "[newt-coder] {} file(s) written via {}",
411 run.files_written.len(),
412 run.emission_shape,
413 );
414
415 Ok(TaskReply::new(run.model_id, content, diff, diff_applied)
416 .map_err(|e| anyhow::anyhow!("newt-coder returned malformed reply: {e}"))?
417 .with_emission_shape(run.emission_shape)
418 .with_raw_emission(run.first_emission))
419 }
420}
421
/// Cheap heuristic for "this model reply is a unified diff".
///
/// It is true when `content` contains both an old-file header marker (`"--- "`) and a
/// new-file header marker (`"+++ "`). The markers may appear anywhere and in either
/// order.
fn looks_like_unified_diff(content: &str) -> bool {
    const HEADER_MARKERS: [&str; 2] = ["--- ", "+++ "];
    HEADER_MARKERS
        .iter()
        .all(|&marker| content.contains(marker))
}
428
429async fn write_response<W: tokio::io::AsyncWrite + Unpin>(
431 writer: &mut W,
432 response: &Value,
433) -> anyhow::Result<()> {
434 let mut out = serde_json::to_string(response)?;
435 out.push('\n');
436 writer.write_all(out.as_bytes()).await?;
437 writer.flush().await?;
438 Ok(())
439}
440
441fn error_response(id: Value, code: i32, message: &str) -> Value {
443 serde_json::json!({
444 "jsonrpc": "2.0",
445 "id": id,
446 "error": { "code": code, "message": message },
447 })
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a reply that the test expects `TaskReply::new` to accept.
    fn reply(model_id: &str, content: &str, diff: &str, applied: bool) -> TaskReply {
        TaskReply::new(model_id, content, diff, applied)
            .expect("non-empty model_id must be accepted")
    }

    #[test]
    fn task_reply_rejects_empty_model_id() {
        let message = match TaskReply::new("", "content", "", false) {
            Ok(accepted) => panic!("empty model_id was accepted: {accepted:?}"),
            Err(e) => e.to_string(),
        };
        assert!(
            message.contains("mandatory"),
            "expected mandatory-id error, got: {message}"
        );
    }

    #[test]
    fn task_reply_accepts_nonempty_model_id() {
        let r = reply("qwen2.5-coder:32b", "hi", "", false);
        assert_eq!(
            (r.model_id.as_str(), r.content.as_str()),
            ("qwen2.5-coder:32b", "hi")
        );
    }

    #[test]
    fn task_reply_sets_empty_diff_from_diff_string() {
        assert!(reply("m", "c", "", false).empty_diff);
        assert!(!reply("m", "c", "real\nchanges\n", true).empty_diff);
    }

    #[test]
    fn task_reply_serde_round_trip_preserves_model_id() {
        let original = reply("m", "c", "d\n", true);
        let wire = serde_json::to_string(&original).unwrap();
        assert!(wire.contains(r#""model_id":"m""#));
        let decoded: TaskReply = serde_json::from_str(&wire).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn task_reply_deserialize_without_model_id_fails() {
        let missing_id = r#"{"content":"c","diff":"","empty_diff":true,"diff_applied":false}"#;
        let message = serde_json::from_str::<TaskReply>(missing_id)
            .unwrap_err()
            .to_string();
        assert!(
            message.contains("model_id"),
            "expected missing-model_id error, got: {message}"
        );
    }

    #[test]
    fn task_reply_emission_shape_defaults_none() {
        assert!(reply("m", "c", "", false).emission_shape.is_none());
    }

    #[test]
    fn task_reply_with_emission_shape_builder() {
        let r = reply("m", "c", "", false).with_emission_shape("whole_files");
        assert_eq!(r.emission_shape.as_deref(), Some("whole_files"));
    }

    #[test]
    fn task_reply_omits_null_emission_shape_from_wire() {
        let wire = serde_json::to_string(&reply("m", "c", "", false)).unwrap();
        assert!(
            !wire.contains("emission_shape"),
            "expected emission_shape omitted when None, got: {wire}"
        );
    }

    #[test]
    fn task_reply_carries_emission_shape_on_wire_when_set() {
        let shaped = reply("m", "c", "", true).with_emission_shape("whole_files");
        let wire = serde_json::to_string(&shaped).unwrap();
        assert!(wire.contains(r#""emission_shape":"whole_files""#));
        let decoded: TaskReply = serde_json::from_str(&wire).unwrap();
        assert_eq!(decoded.emission_shape.as_deref(), Some("whole_files"));
    }

    #[test]
    fn task_reply_old_wire_without_emission_shape_still_parses() {
        let legacy =
            r#"{"model_id":"m","content":"c","diff":"","empty_diff":true,"diff_applied":false}"#;
        let decoded: TaskReply = serde_json::from_str(legacy).unwrap();
        assert_eq!(decoded.model_id, "m");
        assert!(decoded.emission_shape.is_none());
    }

    #[test]
    fn looks_like_unified_diff_detects_headers() {
        let cases = [
            ("--- a/f\n+++ b/f\n@@ -1,1 +1,1 @@\n-a\n+b\n", true),
            ("just prose", false),
            ("--- only the old header", false),
        ];
        for &(input, expected) in &cases {
            assert_eq!(looks_like_unified_diff(input), expected, "input: {input:?}");
        }
    }
}