use crate::error::{LspError, Result};
use crate::transport::Transport;
use crate::types::{
    ClientCapabilities, ClientInfo, Id, InitializeParams, InitializeResult, NotificationMessage,
    RequestMessage, ResponseMessage, RpcMessage,
};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::{mpsc, oneshot, RwLock};
/// Bookkeeping for one in-flight request sent by `Client::send_request`.
struct PendingRequest {
    /// Completed by the background message loop when a response carrying this
    /// request's ID arrives; the matching `send_request` call awaits the other end.
    sender: oneshot::Sender<ResponseMessage>,
}
/// Asynchronous JSON-RPC client over a [`Transport`] built from a reader and a writer.
///
/// A background task spawned by [`Client::new`] reads incoming messages:
/// responses are routed to the matching in-flight [`Client::send_request`] call,
/// while server-initiated requests and notifications are queued for
/// [`Client::receive_message`].
pub struct Client<R, W> {
    /// Shared with the background reader task.
    transport: Arc<RwLock<Transport<R, W>>>,
    /// Source of outgoing request IDs; starts at 1.
    request_id_counter: AtomicI64,
    /// Requests awaiting a response, keyed by request ID.
    pending_requests: Arc<RwLock<HashMap<Id, PendingRequest>>>,
    /// Set by the reader task (while holding the `pending_requests` lock) once it
    /// has stopped. After that no response can ever arrive, so new requests are
    /// rejected instead of waiting forever.
    reader_closed: Arc<AtomicBool>,
    /// Server-initiated requests and notifications forwarded by the reader task.
    /// The only sender lives in that task, so `recv` yields `None` once it ends.
    message_receiver: Option<mpsc::UnboundedReceiver<RpcMessage>>,
    /// Handle to the background reader task; aborted when the client is dropped.
    message_task: tokio::task::JoinHandle<()>,
}

impl<R, W> Client<R, W>
where
    R: AsyncRead + Unpin + Send + Sync + 'static,
    W: AsyncWrite + Unpin + Send + Sync + 'static,
{
    /// Creates a client over `reader`/`writer` and spawns the background reader task.
    ///
    /// Must be called from within a Tokio runtime, because it uses `tokio::spawn`.
    pub fn new(reader: R, writer: W) -> Self {
        let transport = Arc::new(RwLock::new(Transport::new(reader, writer)));
        let pending_requests: Arc<RwLock<HashMap<Id, PendingRequest>>> =
            Arc::new(RwLock::new(HashMap::new()));
        let reader_closed = Arc::new(AtomicBool::new(false));
        let (message_sender, message_receiver) = mpsc::unbounded_channel::<RpcMessage>();

        // The sender is moved into the task (no copy is kept here), so the channel
        // closes and `receive_message` returns `None` when the task ends.
        let message_task = tokio::spawn(Self::run_message_loop(
            Arc::clone(&transport),
            Arc::clone(&pending_requests),
            Arc::clone(&reader_closed),
            message_sender,
        ));

        Self {
            transport,
            request_id_counter: AtomicI64::new(1),
            pending_requests,
            reader_closed,
            message_receiver: Some(message_receiver),
            message_task,
        }
    }

    /// Body of the background reader task.
    ///
    /// Reads until the transport fails or the receiving side is gone. On exit it
    /// marks the client closed and drops every pending sender, so each waiting
    /// `send_request` returns an error instead of hanging.
    async fn run_message_loop(
        transport: Arc<RwLock<Transport<R, W>>>,
        pending_requests: Arc<RwLock<HashMap<Id, PendingRequest>>>,
        reader_closed: Arc<AtomicBool>,
        message_sender: mpsc::UnboundedSender<RpcMessage>,
    ) {
        loop {
            // NOTE(review): the transport's write lock is held for the whole
            // `read_message().await`, so `send_request`, `send_notification` and
            // `send_response` cannot write until the peer sends something. A peer
            // that waits for the client to speak first would deadlock here.
            // Fixing this needs separate read/write halves in `Transport`, whose
            // API is not visible from this file. Confirm and split.
            let message = {
                let mut guard = transport.write().await;
                match guard.read_message().await {
                    Ok(msg) => msg,
                    Err(e) => {
                        log::error!("Failed to read message: {}", e);
                        break;
                    }
                }
            };

            let rpc_message = match message.parse_rpc_message() {
                Ok(msg) => msg,
                Err(e) => {
                    log::error!("Failed to parse RPC message: {}", e);
                    continue;
                }
            };

            // Match by value so responses are moved to their waiter, not cloned.
            let forwarded = match rpc_message {
                RpcMessage::Response(response) => {
                    Self::complete_pending_request(&pending_requests, response).await;
                    continue;
                }
                RpcMessage::Request(request) => RpcMessage::Request(request),
                RpcMessage::Notification(notification) => {
                    RpcMessage::Notification(notification)
                }
            };

            if message_sender.send(forwarded).is_err() {
                log::error!("Message receiver dropped, stopping message loop");
                break;
            }
        }

        // No further responses can arrive. The flag is set under the same lock that
        // `send_request` checks, so no request can register after this cleanup.
        let mut pending = pending_requests.write().await;
        reader_closed.store(true, Ordering::SeqCst);
        pending.clear();
    }

    /// Delivers `response` to the pending request with the same ID, if any.
    async fn complete_pending_request(
        pending_requests: &RwLock<HashMap<Id, PendingRequest>>,
        response: ResponseMessage,
    ) {
        let id = match &response.id {
            Some(id) => id.clone(),
            None => {
                log::warn!("Received response without an ID: {:?}", response);
                return;
            }
        };
        // The lock guard is a temporary, released at the end of this statement.
        let pending_request = pending_requests.write().await.remove(&id);
        match pending_request {
            Some(pending_request) => {
                if let Err(e) = pending_request.sender.send(response) {
                    log::warn!("Failed to send response to pending request: {:?}", e);
                }
            }
            None => log::warn!("Received response for unknown request ID: {}", id),
        }
    }

    /// Returns a fresh request ID; IDs are sequential numbers starting at 1.
    fn next_request_id(&self) -> Id {
        Id::Number(self.request_id_counter.fetch_add(1, Ordering::SeqCst))
    }

    /// Sends a request and waits for the matching response.
    ///
    /// # Errors
    /// - The background reader has already stopped.
    /// - The transport write fails; the pending entry is removed first.
    /// - The request is cancelled ([`Client::cancel_all_requests`]), or the reader
    ///   stops before a response arrives.
    ///
    /// If the returned future is dropped before completion, the pending entry stays
    /// until a response with this ID arrives or the requests are cancelled.
    pub async fn send_request(
        &self,
        method: impl Into<String>,
        params: Option<serde_json::Value>,
    ) -> Result<ResponseMessage> {
        let id = self.next_request_id();
        let request = match params {
            Some(params) => RequestMessage::with_params(id.clone(), method, params),
            None => RequestMessage::new(id.clone(), method),
        };
        let (response_sender, response_receiver) = oneshot::channel();
        {
            let mut pending = self.pending_requests.write().await;
            // Checked under the lock the reader holds while shutting down; see
            // `run_message_loop`.
            if self.reader_closed.load(Ordering::SeqCst) {
                return Err(LspError::Other(
                    "Connection closed: the message loop has stopped".to_string(),
                ));
            }
            pending.insert(
                id.clone(),
                PendingRequest {
                    sender: response_sender,
                },
            );
        }

        let write_result = {
            let mut transport = self.transport.write().await;
            transport
                .write_rpc_message(&RpcMessage::Request(request))
                .await
        };
        if let Err(e) = write_result {
            // The request never reached the peer; don't leave its entry behind.
            self.pending_requests.write().await.remove(&id);
            return Err(e);
        }

        match response_receiver.await {
            Ok(response) => Ok(response),
            Err(_) => {
                // The sender was dropped without a response (cancelled or closed).
                self.pending_requests.write().await.remove(&id);
                Err(LspError::Other(
                    "Request cancelled or connection closed before a response was received"
                        .to_string(),
                ))
            }
        }
    }

    /// Sends a notification; no response is expected.
    ///
    /// # Errors
    /// Propagates transport write errors.
    pub async fn send_notification(
        &self,
        method: impl Into<String>,
        params: Option<serde_json::Value>,
    ) -> Result<()> {
        let notification = match params {
            Some(params) => NotificationMessage::with_params(method, params),
            None => NotificationMessage::new(method),
        };
        let mut transport = self.transport.write().await;
        transport
            .write_rpc_message(&RpcMessage::Notification(notification))
            .await
    }

    /// Waits for the next server-initiated request or notification.
    ///
    /// Returns `None` once the background reader has stopped and every queued
    /// message has been received.
    pub async fn receive_message(&mut self) -> Option<RpcMessage> {
        match self.message_receiver.as_mut() {
            Some(receiver) => receiver.recv().await,
            None => None,
        }
    }

    /// Replies to a server-initiated request `id`.
    ///
    /// If `error` is `Some`, an error response is sent and `result` is ignored.
    /// Otherwise `result` is sent, defaulting to JSON `null`.
    ///
    /// # Errors
    /// Propagates transport write errors.
    pub async fn send_response(
        &self,
        id: Id,
        result: Option<serde_json::Value>,
        error: Option<crate::error::ResponseError>,
    ) -> Result<()> {
        let response = if let Some(error) = error {
            ResponseMessage::error(Some(id), error)
        } else {
            ResponseMessage::success(id, result.unwrap_or(serde_json::Value::Null))
        };
        let mut transport = self.transport.write().await;
        transport
            .write_rpc_message(&RpcMessage::Response(response))
            .await
    }

    /// Returns `true` if any request is still awaiting a response.
    pub async fn has_pending_requests(&self) -> bool {
        !self.pending_requests.read().await.is_empty()
    }

    /// Number of requests still awaiting a response.
    pub async fn pending_request_count(&self) -> usize {
        self.pending_requests.read().await.len()
    }

    /// Drops every pending request. Each waiting `send_request` call returns an error.
    pub async fn cancel_all_requests(&self) {
        let mut pending = self.pending_requests.write().await;
        pending.clear();
    }

    /// Sends the `initialize` request and decodes its result.
    ///
    /// # Errors
    /// - `LspError::InitializationFailed` if the response carries an error or has
    ///   no result.
    /// - Serialization and transport errors are also propagated.
    pub async fn initialize(&self, params: InitializeParams) -> Result<InitializeResult> {
        let response = self
            .send_request("initialize", Some(serde_json::to_value(params)?))
            .await?;
        if let Some(error) = response.error {
            return Err(LspError::InitializationFailed(format!(
                "Initialize request failed: {}",
                error.message
            )));
        }
        if let Some(result) = response.result {
            Ok(serde_json::from_value(result)?)
        } else {
            Err(LspError::InitializationFailed(
                "Initialize response missing result".to_string(),
            ))
        }
    }

    /// Sends the `initialized` notification with an empty-object `params`.
    pub async fn initialized(&self) -> Result<()> {
        self.send_notification("initialized", Some(serde_json::json!({})))
            .await
    }

    /// Runs the handshake with default client capabilities.
    ///
    /// It sends `initialize` (including this process's ID and the given client info
    /// and root URI), then the `initialized` notification, and returns the
    /// `initialize` result.
    pub async fn initialize_default(
        &self,
        client_name: impl Into<String>,
        client_version: Option<String>,
        root_uri: Option<String>,
    ) -> Result<InitializeResult> {
        let params = InitializeParams {
            process_id: Some(std::process::id()),
            client_info: Some(ClientInfo {
                name: client_name.into(),
                version: client_version,
            }),
            locale: None,
            root_path: None,
            root_uri,
            initialization_options: None,
            capabilities: ClientCapabilities::default(),
            trace: None,
            workspace_folders: None,
        };
        let result = self.initialize(params).await?;
        self.initialized().await?;
        Ok(result)
    }
}

impl<R, W> Drop for Client<R, W> {
    fn drop(&mut self) {
        // Dropping a `JoinHandle` only detaches the task. Without `abort` the reader
        // would keep running, and keep the transport alive, after the client is gone.
        self.message_task.abort();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::io::Cursor;

    /// Builds a client whose reader is empty and whose writer is an in-memory buffer.
    fn in_memory_client() -> Client<Cursor<Vec<u8>>, Cursor<Vec<u8>>> {
        Client::new(Cursor::new(Vec::new()), Cursor::new(Vec::new()))
    }

    #[tokio::test]
    async fn test_client_creation() {
        let client = in_memory_client();
        assert!(!client.has_pending_requests().await);
        assert_eq!(client.pending_request_count().await, 0);
    }

    #[tokio::test]
    async fn test_request_id_generation() {
        let client = in_memory_client();
        let first = client.next_request_id();
        let second = client.next_request_id();
        assert_ne!(first, second);
        match (first, second) {
            (Id::Number(a), Id::Number(b)) => assert!(a < b),
            _ => {}
        }
    }

    #[tokio::test]
    async fn test_send_notification() {
        let client = in_memory_client();
        let with_params = client
            .send_notification("test/notification", Some(json!({"key": "value"})))
            .await;
        assert!(with_params.is_ok());
        assert!(client.send_notification("test/simple", None).await.is_ok());
    }

    #[test]
    fn test_pending_request_struct() {
        let (sender, _receiver) = oneshot::channel();
        let _pending = PendingRequest { sender };
    }
}