use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use serde_json::{json, Value};
use smooth_operator_core::extension::manifest::{default_global_dir, project_dir};
use smooth_operator_core::extension::protocol::{HostInfo, RpcError, WorkspaceInfo};
use smooth_operator_core::extension::{discover, DiscoveredExtension, ExtensionHost, HostDelegate};
use smooth_operator_core::HumanResponse;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use crate::runner::{ClearConfirmation, RegisterConfirmation};
use crate::state::AppState;
/// UI mode string handed to `ExtensionHost::load` for every turn.
const UI_MODE: &str = "widget";
/// Upper bound on how long an extension `confirm` request waits for a human
/// answer before it resolves as cancelled (five minutes).
const UI_CONFIRM_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// An extension host attached to a single turn, as returned by
/// `build_extension_host`.
pub struct ExtensionTurn {
    /// The loaded extension host, shared behind an `Arc`.
    pub host: Arc<ExtensionHost>,
    /// Session id the host was built for; the turn's confirmation delegate
    /// registers its responders under this same id.
    pub session_id: String,
    /// Callback that forwards to `AppState::clear_confirmation` for a session id.
    /// NOTE(review): presumably invoked by the caller when the turn ends, so that a
    /// pending extension `confirm` sees its responder dropped and resolves as
    /// cancelled — confirm against the runner.
    pub clear: ClearConfirmation,
}
/// `HostDelegate` that answers extension UI requests for one request/session.
///
/// `confirm` prompts are forwarded to the client as
/// `write_confirmation_required` frames; render-only requests are accepted
/// without being displayed.
struct ConfirmUiProvider {
    /// Outgoing frame channel; confirmation prompts are sent here.
    sink: UnboundedSender<Value>,
    /// Request id passed to `write_confirmation_required` for emitted frames.
    request_id: String,
    /// Session id under which the confirmation responder is registered.
    session_id: String,
    /// Registers the channel that will receive the human's answer for `session_id`.
    register: RegisterConfirmation,
}
/// Decides whether the extension `name` may run for this turn.
///
/// The name must appear in the server allowlist `allow`. When
/// `enabled_extensions` is `Some`, it must also appear in that per-agent list
/// (an empty per-agent list therefore disables everything). When it is `None`,
/// the server allowlist alone decides.
fn extension_allowed(name: &str, allow: &[String], enabled_extensions: Option<&[String]>) -> bool {
    let server_permits = allow.iter().any(|entry| entry.as_str() == name);
    if !server_permits {
        return false;
    }
    match enabled_extensions {
        None => true,
        Some(per_agent) => per_agent.iter().any(|id| id.as_str() == name),
    }
}
/// Parses a comma-separated allowlist such as `" todo , gate "`.
///
/// Each entry is trimmed and blank entries are dropped. Input order is kept
/// and duplicates are not removed. `None` yields an empty list, which callers
/// treat as "deny all".
fn parse_allowlist(raw: Option<&str>) -> Vec<String> {
    let Some(csv) = raw else {
        return Vec::new();
    };
    csv.split(',')
        .filter_map(|piece| {
            let piece = piece.trim();
            (!piece.is_empty()).then(|| piece.to_owned())
        })
        .collect()
}
#[async_trait]
impl HostDelegate for ConfirmUiProvider {
async fn ui_request(&self, ext: &str, params: Value) -> Result<Value, RpcError> {
let kind = params
.get("kind")
.and_then(Value::as_str)
.unwrap_or_default();
match kind {
"confirm" => {
let prompt = params
.get("prompt")
.and_then(Value::as_str)
.unwrap_or("Confirm this action?");
let (tx, mut rx) = unbounded_channel::<HumanResponse>();
(self.register)(&self.session_id, tx);
let _ = self.sink.send(crate::protocol::write_confirmation_required(
&self.request_id,
ext,
prompt,
));
match tokio::time::timeout(UI_CONFIRM_TIMEOUT, rx.recv()).await {
Ok(Some(HumanResponse::Approved)) => Ok(json!({ "confirmed": true })),
Ok(Some(HumanResponse::Denied { .. })) => Ok(json!({ "confirmed": false })),
_ => Ok(json!({ "cancelled": true })),
}
}
"notify" | "set_status" | "set_widget" | "set_title" => Ok(json!({})),
_ => Ok(json!({ "cancelled": true })),
}
}
}
/// Resolves the global extension directory.
///
/// Uses `SMOOTH_EXTENSIONS_DIR` when it is set to a non-blank value (trimmed).
/// Otherwise falls back to `default_global_dir`.
fn global_extensions_dir() -> Option<std::path::PathBuf> {
    std::env::var("SMOOTH_EXTENSIONS_DIR")
        .ok()
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty())
        .map(std::path::PathBuf::from)
        .or_else(default_global_dir)
}

/// Keeps only the discovered extensions that pass `extension_allowed`,
/// logging each one that is skipped.
fn retain_allowed(
    discovered: impl IntoIterator<Item = DiscoveredExtension>,
    allow: &[String],
    enabled_extensions: Option<&[String]>,
) -> Vec<DiscoveredExtension> {
    discovered
        .into_iter()
        .filter(|ext| {
            let ok = extension_allowed(&ext.manifest.name, allow, enabled_extensions);
            if !ok {
                tracing::info!(name = %ext.manifest.name, "sep: skipping extension — not in SMOOTH_EXTENSIONS_ALLOW ∩ per-agent enabled extensions");
            }
            ok
        })
        .collect()
}

/// Builds the extension host for a single turn, or returns `None` when no
/// extension should run.
///
/// 1. Parses `SMOOTH_EXTENSIONS_ALLOW` (comma-separated names). An unset or
///    blank allowlist disables extensions entirely (deny by default).
/// 2. Discovers manifests in the global directory (`global_extensions_dir`)
///    and in the project directory derived from the current working directory.
///    Manifests that fail to parse are logged and skipped.
/// 3. Keeps only extensions allowed by the server allowlist and, when `Some`,
///    by `enabled_extensions`.
/// 4. Loads them with a `ConfirmUiProvider`. That provider routes `confirm`
///    prompts to `sink` tagged with `request_id` and registers responders
///    under `session_id`.
///
/// Returns `None` if the allowlist is empty, nothing survives filtering, or no
/// extension loads successfully. Load failures are logged.
pub async fn build_extension_host(
    state: &AppState,
    session_id: &str,
    request_id: &str,
    sink: UnboundedSender<Value>,
    enabled_extensions: Option<&[String]>,
) -> Option<ExtensionTurn> {
    let allow = parse_allowlist(std::env::var("SMOOTH_EXTENSIONS_ALLOW").ok().as_deref());
    if allow.is_empty() {
        return None;
    }

    // Read the working directory once, so the project extension directory and
    // the workspace root reported to extensions always describe the same place.
    let cwd = std::env::current_dir().ok();

    let global = global_extensions_dir();
    let project = cwd.as_ref().map(|d| project_dir(d));
    let (discovered, disc_failures) = discover(global.as_deref(), project.as_deref());
    for (src, err) in &disc_failures {
        tracing::warn!(%src, %err, "sep: extension manifest failed to parse");
    }

    let allowed = retain_allowed(discovered, &allow, enabled_extensions);
    if allowed.is_empty() {
        return None;
    }

    let host_info = HostInfo {
        name: "smooth-operator-server".into(),
        version: env!("CARGO_PKG_VERSION").into(),
    };
    let workspace = WorkspaceInfo {
        root: cwd
            .as_ref()
            .map(|d| d.to_string_lossy().into_owned())
            .unwrap_or_default(),
        trusted: true,
    };

    let register: RegisterConfirmation = {
        let state = state.clone();
        Arc::new(move |sid: &str, responder| state.register_confirmation(sid, responder))
    };
    let clear: ClearConfirmation = {
        let state = state.clone();
        Arc::new(move |sid: &str| state.clear_confirmation(sid))
    };
    let delegate = Arc::new(ConfirmUiProvider {
        sink,
        request_id: request_id.to_string(),
        session_id: session_id.to_string(),
        register,
    });

    let (host, load_failures) = ExtensionHost::load(
        allowed,
        host_info,
        workspace,
        UI_MODE,
        vec!["confirm".to_string()],
        delegate,
    )
    .await;
    for (name, err) in &load_failures {
        tracing::warn!(%name, %err, "sep: extension failed to load");
    }
    if host.is_empty() {
        return None;
    }

    tracing::info!(count = host.len(), extensions = ?host.names(), "sep: attached extension host to the turn");
    Some(ExtensionTurn {
        host: Arc::new(host),
        session_id: session_id.to_string(),
        clear,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use tokio::sync::mpsc::UnboundedReceiver;

    /// Slot where the test `register` callback parks the most recent responder.
    type ResponderSlot = Arc<Mutex<Option<UnboundedSender<HumanResponse>>>>;

    /// A `RegisterConfirmation` that stores whatever responder it receives in `slot`.
    fn test_register(slot: ResponderSlot) -> RegisterConfirmation {
        Arc::new(move |_sid: &str, responder| {
            let mut parked = slot.lock().unwrap();
            *parked = Some(responder);
        })
    }

    /// A provider that writes frames to `sink` and parks responders in `slot`.
    fn provider(sink: UnboundedSender<Value>, slot: ResponderSlot) -> ConfirmUiProvider {
        ConfirmUiProvider {
            register: test_register(slot),
            session_id: String::from("sess-1"),
            request_id: String::from("req-1"),
            sink,
        }
    }

    /// A fresh provider, the receiving end of its frame sink, and its responder slot.
    fn fixture() -> (ConfirmUiProvider, UnboundedReceiver<Value>, ResponderSlot) {
        let (frame_tx, frame_rx) = unbounded_channel::<Value>();
        let slot: ResponderSlot = Arc::new(Mutex::new(None));
        (provider(frame_tx, Arc::clone(&slot)), frame_rx, slot)
    }

    #[test]
    fn allowlist_parses_csv_and_denies_by_default() {
        let deny_cases: [(Option<&str>, &str); 3] = [
            (None, "unset ⇒ deny all"),
            (Some(""), "blank ⇒ deny all"),
            (Some(" , ,"), "only separators ⇒ deny all"),
        ];
        for (raw, why) in deny_cases {
            assert!(parse_allowlist(raw).is_empty(), "{why}");
        }
        assert_eq!(parse_allowlist(Some("todo")), vec![String::from("todo")]);
        assert_eq!(
            parse_allowlist(Some(" todo , gate ")),
            vec![String::from("todo"), String::from("gate")]
        );
    }

    #[test]
    fn extension_allowed_intersects_server_allowlist_with_per_agent_ids() {
        let server: Vec<String> = ["a", "b"].iter().map(|s| s.to_string()).collect();
        assert!(extension_allowed("a", &server, None));
        assert!(extension_allowed("b", &server, None));
        assert!(
            !extension_allowed("c", &server, None),
            "not in server allowlist"
        );

        let just_a = vec![String::from("a")];
        assert!(extension_allowed("a", &server, Some(&just_a)));
        assert!(
            !extension_allowed("b", &server, Some(&just_a)),
            "b is allowed by server but NOT enabled per-agent"
        );

        let nothing: Vec<String> = Vec::new();
        for name in ["a", "b"] {
            assert!(!extension_allowed(name, &server, Some(&nothing)));
        }

        let only_c = vec![String::from("c")];
        assert!(!extension_allowed("c", &server, Some(&only_c)));
    }

    #[tokio::test]
    async fn confirm_emits_frame_and_resolves_on_approval() {
        let (p, mut frames, slot) = fixture();
        let pending = tokio::spawn(async move {
            p.ui_request("todo", json!({ "kind": "confirm", "prompt": "Delete file?" }))
                .await
        });

        let frame = frames.recv().await.expect("frame");
        assert_eq!(frame["type"], "write_confirmation_required");
        let payload = &frame["data"]["data"];
        assert_eq!(payload["toolId"], "todo");
        assert_eq!(payload["actionDescription"], "Delete file?");

        let responder = slot.lock().unwrap().take().expect("responder registered");
        responder.send(HumanResponse::Approved).unwrap();
        assert_eq!(pending.await.unwrap().unwrap(), json!({ "confirmed": true }));
    }

    #[tokio::test]
    async fn confirm_resolves_false_on_denial() {
        let (p, mut frames, slot) = fixture();
        let pending = tokio::spawn(async move {
            p.ui_request("gate", json!({ "kind": "confirm", "prompt": "Proceed?" }))
                .await
        });

        frames.recv().await.expect("frame");
        let responder = slot.lock().unwrap().take().expect("responder");
        let denial = HumanResponse::Denied {
            reason: "no".into(),
        };
        responder.send(denial).unwrap();
        assert_eq!(pending.await.unwrap().unwrap(), json!({ "confirmed": false }));
    }

    #[tokio::test]
    async fn confirm_cancels_when_turn_ends() {
        let (p, mut frames, slot) = fixture();
        let pending = tokio::spawn(async move {
            p.ui_request("x", json!({ "kind": "confirm", "prompt": "Go?" }))
                .await
        });

        frames.recv().await.expect("frame");
        // Dropping the registered responder is what ending the turn looks like.
        let responder = slot.lock().unwrap().take();
        drop(responder);
        assert_eq!(pending.await.unwrap().unwrap(), json!({ "cancelled": true }));
    }

    #[tokio::test]
    async fn render_only_kinds_accept_and_drop() {
        let (p, _frames, _slot) = fixture();
        for kind in ["notify", "set_status", "set_widget", "set_title"] {
            let request =
                json!({ "kind": kind, "message": "hi", "status": "s", "widget": {}, "title": "t" });
            let reply = p.ui_request("x", request).await.unwrap();
            assert_eq!(reply, json!({}), "kind {kind}");
        }
    }

    #[tokio::test]
    async fn unsupported_interactive_kinds_cancel() {
        let (p, _frames, _slot) = fixture();
        for kind in ["select", "input"] {
            let request = json!({ "kind": kind, "prompt": "?", "options": ["a"] });
            let reply = p.ui_request("x", request).await.unwrap();
            assert_eq!(reply, json!({ "cancelled": true }), "kind {kind}");
        }
    }
}