use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use serde_json::Value;
use tokio::sync::{mpsc, Mutex};
use crate::jsonrpc::JsonRpcClient;
use crate::types::*;
use crate::CopilotError;
/// Boxed, `Send` future produced by every asynchronous handler callback below.
type HandlerFuture<T> = Pin<Box<dyn Future<Output = Result<T, CopilotError>> + Send>>;

/// Asynchronous callback implementing a client-side tool.
///
/// Receives a JSON value (presumably the tool-call arguments — confirm against
/// the caller) together with the [`ToolInvocation`] describing the call.
pub type ToolHandler =
    Arc<dyn Fn(Value, ToolInvocation) -> HandlerFuture<Value> + Send + Sync>;

/// Asynchronous callback deciding a [`PermissionRequest`].
///
/// The second argument is the id of the session the request belongs to.
pub type PermissionHandlerFn = Arc<
    dyn Fn(PermissionRequest, String) -> HandlerFuture<PermissionRequestResult> + Send + Sync,
>;

/// Asynchronous callback answering a [`UserInputRequest`].
///
/// The second argument is the id of the session the request belongs to.
pub type UserInputHandlerFn =
    Arc<dyn Fn(UserInputRequest, String) -> HandlerFuture<UserInputResponse> + Send + Sync>;

/// Asynchronous callback for hook invocations.
///
/// Arguments are, in order: the hook type, the hook's JSON input, and the
/// session id. It may resolve to an optional JSON value.
pub type HooksHandlerFn =
    Arc<dyn Fn(String, Value, String) -> HandlerFuture<Option<Value>> + Send + Sync>;

/// Synchronous callback that receives session events.
pub type SessionEventHandlerFn = Arc<dyn Fn(SessionEvent) + Send + Sync>;

/// Callback registered for one specific event type; same shape as
/// [`SessionEventHandlerFn`].
pub type TypedSessionEventHandlerFn = SessionEventHandlerFn;
/// Handle for a registration that ends when this value goes away.
///
/// Dropping the handle, or calling [`Subscription::unsubscribe`], runs its
/// unsubscribe callback exactly once. `CopilotSession::on` and
/// `CopilotSession::on_event` return one of these, and their callback removes
/// the event handler. Discarding the value (e.g. `let _ = session.on(h).await;`)
/// drops it immediately and unregisters the handler at once. Keep it in a
/// named binding for as long as the handler should stay active.
#[must_use = "dropping a Subscription immediately runs its unsubscribe callback"]
pub struct Subscription {
    // Callback to run when the subscription ends; `Drop` takes it, so it runs at most once.
    unsubscribe_fn: Option<Box<dyn FnOnce() + Send>>,
}

impl Subscription {
    /// Creates a subscription that will run `f` when it ends.
    fn new(f: impl FnOnce() + Send + 'static) -> Self {
        Self {
            unsubscribe_fn: Some(Box::new(f)),
        }
    }

    /// Ends the subscription now, running the unsubscribe callback.
    ///
    /// Equivalent to dropping the value; it exists so call sites can state the intent explicitly.
    pub fn unsubscribe(self) {
        // `Drop` does the work, so there is a single code path for both ways of ending it.
        drop(self);
    }
}

impl Drop for Subscription {
    fn drop(&mut self) {
        if let Some(f) = self.unsubscribe_fn.take() {
            f();
        }
    }
}

impl std::fmt::Debug for Subscription {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The boxed callback is opaque, so only the type name is shown.
        f.debug_struct("Subscription").finish_non_exhaustive()
    }
}
pub struct CopilotSession {
session_id: String,
workspace_path: Option<String>,
rpc_client: Arc<JsonRpcClient>,
tool_handlers: Arc<Mutex<HashMap<String, ToolHandler>>>,
permission_handler: Arc<Mutex<Option<PermissionHandlerFn>>>,
user_input_handler: Arc<Mutex<Option<UserInputHandlerFn>>>,
hooks_handler: Arc<Mutex<Option<HooksHandlerFn>>>,
event_handlers: Arc<Mutex<Vec<(u64, SessionEventHandlerFn)>>>,
typed_event_handlers: Arc<Mutex<HashMap<String, Vec<(u64, TypedSessionEventHandlerFn)>>>>,
next_handler_id: Arc<Mutex<u64>>,
}
impl CopilotSession {
pub(crate) fn new(
session_id: String,
rpc_client: Arc<JsonRpcClient>,
workspace_path: Option<String>,
) -> Self {
Self {
session_id,
workspace_path,
rpc_client,
tool_handlers: Arc::new(Mutex::new(HashMap::new())),
permission_handler: Arc::new(Mutex::new(None)),
user_input_handler: Arc::new(Mutex::new(None)),
hooks_handler: Arc::new(Mutex::new(None)),
event_handlers: Arc::new(Mutex::new(Vec::new())),
typed_event_handlers: Arc::new(Mutex::new(HashMap::new())),
next_handler_id: Arc::new(Mutex::new(0)),
}
}
pub fn session_id(&self) -> &str {
&self.session_id
}
pub fn workspace_path(&self) -> Option<&str> {
self.workspace_path.as_deref()
}
pub async fn send(&self, options: MessageOptions) -> Result<String, CopilotError> {
let params = serde_json::json!({
"sessionId": self.session_id,
"prompt": options.prompt,
"attachments": options.attachments,
"mode": options.mode,
"responseFormat": options.response_format,
"imageOptions": options.image_options,
});
let response = self.rpc_client.request("session.send", params, None).await?;
let message_id = response
.get("messageId")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
Ok(message_id)
}
pub async fn send_and_wait(
&self,
options: MessageOptions,
timeout: Option<u64>,
) -> Result<Option<SessionEvent>, CopilotError> {
let effective_timeout = timeout.unwrap_or(60_000);
let (idle_tx, mut idle_rx) = mpsc::channel::<Result<(), CopilotError>>(1);
let last_assistant_message: Arc<Mutex<Option<SessionEvent>>> =
Arc::new(Mutex::new(None));
let last_msg_clone = Arc::clone(&last_assistant_message);
let idle_tx_clone = idle_tx.clone();
let sub = self
.on(move |event: SessionEvent| {
if event.is_assistant_message() {
let mut msg = last_msg_clone.blocking_lock();
*msg = Some(event);
} else if event.is_session_idle() {
let _ = idle_tx_clone.try_send(Ok(()));
} else if event.is_session_error() {
let error_msg = event
.error_message()
.unwrap_or("Unknown error")
.to_string();
let _ = idle_tx_clone.try_send(Err(CopilotError::SessionError(error_msg)));
}
})
.await;
self.send(options).await?;
let result = tokio::time::timeout(
std::time::Duration::from_millis(effective_timeout),
idle_rx.recv(),
)
.await;
sub.unsubscribe();
match result {
Ok(Some(Ok(()))) => {
let msg = last_assistant_message.lock().await;
Ok(msg.clone())
}
Ok(Some(Err(e))) => Err(e),
Ok(None) => Err(CopilotError::ConnectionClosed),
Err(_) => Err(CopilotError::Timeout(effective_timeout)),
}
}
pub async fn on<F>(&self, handler: F) -> Subscription
where
F: Fn(SessionEvent) + Send + Sync + 'static,
{
let handler_id = {
let mut id = self.next_handler_id.lock().await;
*id += 1;
*id
};
let handler_arc: SessionEventHandlerFn = Arc::new(handler);
{
let mut handlers = self.event_handlers.lock().await;
handlers.push((handler_id, handler_arc));
}
let event_handlers = Arc::clone(&self.event_handlers);
Subscription::new(move || {
let mut handlers = event_handlers.blocking_lock();
handlers.retain(|(id, _)| *id != handler_id);
})
}
pub async fn on_event<F>(&self, event_type: &str, handler: F) -> Subscription
where
F: Fn(SessionEvent) + Send + Sync + 'static,
{
let handler_id = {
let mut id = self.next_handler_id.lock().await;
*id += 1;
*id
};
let handler_arc: TypedSessionEventHandlerFn = Arc::new(handler);
let event_type_str = event_type.to_string();
{
let mut handlers = self.typed_event_handlers.lock().await;
handlers
.entry(event_type_str.clone())
.or_default()
.push((handler_id, handler_arc));
}
let typed_handlers = Arc::clone(&self.typed_event_handlers);
let et = event_type_str;
Subscription::new(move || {
let mut handlers = typed_handlers.blocking_lock();
if let Some(list) = handlers.get_mut(&et) {
list.retain(|(id, _)| *id != handler_id);
}
})
}
pub(crate) async fn dispatch_event(&self, event: SessionEvent) {
{
let handlers = self.typed_event_handlers.lock().await;
if let Some(list) = handlers.get(&event.event_type) {
for (_, handler) in list {
handler(event.clone());
}
}
}
{
let handlers = self.event_handlers.lock().await;
for (_, handler) in handlers.iter() {
handler(event.clone());
}
}
}
pub async fn register_tool(&self, name: &str, handler: ToolHandler) {
let mut handlers = self.tool_handlers.lock().await;
handlers.insert(name.to_string(), handler);
}
pub async fn register_tools(&self, tools: Vec<(String, ToolHandler)>) {
let mut handlers = self.tool_handlers.lock().await;
handlers.clear();
for (name, handler) in tools {
handlers.insert(name, handler);
}
}
pub(crate) async fn get_tool_handler(&self, name: &str) -> Option<ToolHandler> {
let handlers = self.tool_handlers.lock().await;
handlers.get(name).cloned()
}
pub async fn register_permission_handler(&self, handler: PermissionHandlerFn) {
let mut h = self.permission_handler.lock().await;
*h = Some(handler);
}
pub(crate) async fn handle_permission_request(
&self,
request: Value,
) -> Result<PermissionRequestResult, CopilotError> {
let handler = self.permission_handler.lock().await;
if let Some(ref h) = *handler {
let perm_request: PermissionRequest = serde_json::from_value(request)
.map_err(|e| CopilotError::Serialization(e.to_string()))?;
h(perm_request, self.session_id.clone()).await
} else {
Ok(PermissionRequestResult {
kind: PermissionResultKind::DeniedNoApprovalRuleAndCouldNotRequestFromUser,
rules: None,
})
}
}
pub async fn register_user_input_handler(&self, handler: UserInputHandlerFn) {
let mut h = self.user_input_handler.lock().await;
*h = Some(handler);
}
pub(crate) async fn handle_user_input_request(
&self,
request: Value,
) -> Result<UserInputResponse, CopilotError> {
let handler = self.user_input_handler.lock().await;
if let Some(ref h) = *handler {
let input_request: UserInputRequest = serde_json::from_value(request)
.map_err(|e| CopilotError::Serialization(e.to_string()))?;
h(input_request, self.session_id.clone()).await
} else {
Err(CopilotError::NoHandler(
"User input requested but no handler registered".to_string(),
))
}
}
pub async fn register_hooks_handler(&self, handler: HooksHandlerFn) {
let mut h = self.hooks_handler.lock().await;
*h = Some(handler);
}
pub(crate) async fn handle_hooks_invoke(
&self,
hook_type: &str,
input: Value,
) -> Result<Option<Value>, CopilotError> {
let handler = self.hooks_handler.lock().await;
if let Some(ref h) = *handler {
h(hook_type.to_string(), input, self.session_id.clone()).await
} else {
Ok(None)
}
}
pub async fn get_messages(&self) -> Result<Vec<SessionEvent>, CopilotError> {
let params = serde_json::json!({ "sessionId": self.session_id });
let response = self
.rpc_client
.request("session.getMessages", params, None)
.await?;
let events: Vec<SessionEvent> = serde_json::from_value(
response
.get("events")
.cloned()
.unwrap_or(Value::Array(vec![])),
)
.map_err(|e| CopilotError::Serialization(e.to_string()))?;
Ok(events)
}
pub async fn get_metadata(&self) -> Result<Value, CopilotError> {
let params = serde_json::json!({ "sessionId": self.session_id });
let response = self
.rpc_client
.request("session.getMetadata", params, None)
.await?;
Ok(response)
}
pub async fn destroy(&self) -> Result<(), CopilotError> {
let params = serde_json::json!({ "sessionId": self.session_id });
self.rpc_client
.request("session.destroy", params, None)
.await?;
{
let mut handlers = self.event_handlers.lock().await;
handlers.clear();
}
{
let mut handlers = self.typed_event_handlers.lock().await;
handlers.clear();
}
{
let mut handlers = self.tool_handlers.lock().await;
handlers.clear();
}
{
let mut handler = self.permission_handler.lock().await;
*handler = None;
}
{
let mut handler = self.user_input_handler.lock().await;
*handler = None;
}
Ok(())
}
pub async fn abort(&self) -> Result<(), CopilotError> {
let params = serde_json::json!({ "sessionId": self.session_id });
self.rpc_client
.request("session.abort", params, None)
.await?;
Ok(())
}
}