use {
crate::{
connect::{
ipc::Connection,
lsp::{
ClientId,
ClientRegistry,
},
},
database::{
PartitionWriteContextRef,
Partitions,
},
protocol::{
ipc::{
Endpoint,
Handshake,
IpcListener,
framing,
},
jsonrpc::Message,
},
scheduler::task::TaskContext,
},
std::{
io,
sync::Arc,
},
};
/// Callback run with the [`ClientId`] of a client after its IPC connection has
/// closed and the client has been removed from the registry.
///
/// It is reference-counted so the server can hand a clone to each connection's
/// background task; `Send + Sync` allows it to be invoked from those tasks.
pub type DisconnectCallback = Arc<dyn Fn(ClientId) + Sync + Send + 'static>;
/// IPC server that accepts client connections, performs a version handshake with
/// each one, and registers accepted clients in a shared [`ClientRegistry`].
pub struct IpcServer {
/// Listener that incoming client connections are accepted from.
listener: IpcListener,
/// Registry that accepted clients are added to, and removed from when their
/// inbound stream ends.
registry: Arc<ClientRegistry>,
/// Server version string used to build this side's handshake.
version: String,
/// Optional hook invoked with a client's id after that client disconnects.
disconnect_callback: Option<DisconnectCallback>,
}
impl IpcServer {
pub async fn bind(
endpoint: &Endpoint,
registry: Arc<ClientRegistry>,
version: impl Into<String>,
) -> io::Result<Self> {
let listener = IpcListener::bind(endpoint).await?;
Ok(Self {
listener,
registry,
version: version.into(),
disconnect_callback: None,
})
}
pub fn set_disconnect_callback(&mut self, callback: DisconnectCallback) {
self.disconnect_callback = Some(callback);
}
pub async fn accept_client(&mut self) -> io::Result<ClientId> {
let stream = self.listener.accept().await?;
let (mut reader, mut writer) = stream.into_split();
let mut client_hs = framing::recv_handshake(&mut reader).await?;
let server_hs = Handshake::new(&self.version);
if !server_hs.is_compatible(&client_hs) {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"version mismatch: server={}, client={}",
server_hs.version, client_hs.version
),
));
}
framing::send_handshake(&mut writer, &server_hs).await?;
let (client_sender, to_ipc) = async_channel::unbounded::<Message>();
let (from_ipc, client_receiver) = async_channel::unbounded::<Message>();
let connection = Connection {
sender: client_sender,
receiver: client_receiver,
};
let kind = client_hs.client_kind();
let metadata = client_hs.take_metadata();
let client_id = self.registry.register(kind, connection, metadata);
smol::spawn(async move {
while let Ok(msg) = to_ipc.recv().await {
if framing::send_message(&mut writer, &msg).await.is_err() {
break;
}
}
})
.detach();
let registry = self.registry.clone();
let id = client_id;
let disconnect_cb = self.disconnect_callback.clone();
smol::spawn(async move {
while let Ok(Some(msg)) = framing::recv_message(&mut reader).await {
if from_ipc.send(msg).await.is_err() {
break;
}
}
registry.unregister(id);
if let Some(cb) = disconnect_cb {
cb(id);
}
})
.detach();
Ok(client_id)
}
pub async fn accept_loop(mut self) {
loop {
match self.accept_client().await {
| Ok(client_id) => {
otel::event!(
"ipc_client_connected",
"client_id" = client_id.as_u64() as i64
);
},
| Err(e) => {
otel::error!("ipc_accept_error", format!("IPC accept error: {}", e));
},
}
}
}
pub fn registry(&self) -> &Arc<ClientRegistry> {
&self.registry
}
}
/// Field-less marker type for the Laburnum language server.
///
/// Its behavior comes entirely from the trait impls in this module (LSP services,
/// key watcher, hooks).
pub struct LaburnumLanguageServer;
// Opts into `SourceFileMeta` with an empty impl body, so only the trait's provided
// items (if it has any) are used.
impl crate::protocol::lsp::SourceFileMeta for LaburnumLanguageServer {}
// Key-watcher dispatch for `LaburnumLanguageServer`. The implementation below is a
// no-op for every `P: Partitions`.
impl<P: Partitions>
crate::scheduler::key_watcher::KeyWatcher<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
/// Ignores every notification.
///
/// `_pk` (presumably the partition key), the `_updated_sks` / `_deleted_sks` lists
/// (presumably updated and deleted sort keys) and the `_spawn` scheduler are all
/// discarded, and `_spawn` is never called.
///
/// NOTE(review): presumably `_spawn` is how an implementation would schedule a
/// watcher future for each change — confirm this server really needs no watchers.
fn dispatch_watcher<F>(
_pk: crate::Ident,
_updated_sks: Vec<String>,
_deleted_sks: Vec<String>,
_spawn: F,
) where
F: Fn(
crate::Ident,
Vec<String>,
Vec<String>,
// Higher-ranked fn pointer: the watcher body borrows the task and write
// contexts for the same lifetime `'a` as the boxed future it returns.
for<'a> fn(
&'a mut TaskContext<P, LaburnumLanguageServer>,
&'a mut PartitionWriteContextRef<'a, P>,
) -> std::pin::Pin<
Box<
dyn std::future::Future<
Output = crate::scheduler::key_watcher::WatcherResult<
P,
LaburnumLanguageServer,
>,
> + Send
+ 'a,
>,
>,
),
{
}
}
impl<P: Partitions>
crate::protocol::lsp::InitializeService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::MonikerService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::HoverService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::GotoService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::CompletionService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::DocumentColorService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::SignatureHelpService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::CodeLensService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::WorkspaceDiagnosticService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::TextDocumentService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::TextDocumentHookService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::SymbolService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::SemanticTokensService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::TypeHierarchyService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::InlineValueService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::InlayHintService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::CallHierarchyService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::CodeActionService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::DocumentSymbolService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::ExecuteCommandService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::DocumentLinkService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::RenameService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::NotebookDocumentService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::ReferenceService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::SelectionRangeService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::DiagnosticService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::LinkedEditingService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::FormattingService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::WorkspaceService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::WorkspaceHooksService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::FoldingRangeService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions>
crate::protocol::lsp::DocumentHighlightService<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
impl<P: Partitions> crate::hooks::LaburnumHooks<P, LaburnumLanguageServer>
for LaburnumLanguageServer
{
}
#[cfg(test)]
mod tests {
    use {
        super::*,
        crate::connect::lsp::ClientKind,
        crate::protocol::ipc::MemoryTransport,
    };

    /// A compatible client is accepted, registered, and reported with its kind.
    #[test]
    fn ipc_server_accept_client() {
        smol::block_on(async {
            let transport = MemoryTransport::new();
            let endpoint = Endpoint::memory(transport.clone(), "test-server");
            let registry = Arc::new(ClientRegistry::new());
            let mut server = IpcServer::bind(&endpoint, registry.clone(), "v1.0.0")
                .await
                .unwrap();
            let client_endpoint = Endpoint::memory(transport.clone(), "test-server");
            let client_task = smol::spawn(async move {
                let stream = crate::protocol::ipc::IpcStream::connect(&client_endpoint)
                    .await
                    .unwrap();
                let (mut reader, mut writer) = stream.into_split();
                let client_hs = Handshake::new("v1.0.0");
                framing::send_handshake(&mut writer, &client_hs)
                    .await
                    .unwrap();
                let _server_hs = framing::recv_handshake(&mut reader).await.unwrap();
                (reader, writer)
            });
            let client_id = server.accept_client().await.unwrap();
            // Hold both halves so the connection stays open while the registry is
            // inspected. Ending the server's inbound stream unregisters the client.
            let (_reader, _writer) = client_task.await;
            assert_eq!(registry.client_count(), 1);
            let client = registry
                .get(client_id)
                .expect("accepted client should be registered");
            assert_eq!(client.kind(), ClientKind::Cli);
        });
    }

    /// An incompatible client version is rejected with `InvalidData`.
    #[test]
    fn ipc_server_version_mismatch() {
        smol::block_on(async {
            let transport = MemoryTransport::new();
            let endpoint = Endpoint::memory(transport.clone(), "mismatch-test");
            let registry = Arc::new(ClientRegistry::new());
            let mut server = IpcServer::bind(&endpoint, registry.clone(), "v2.0.0")
                .await
                .unwrap();
            let client_endpoint = Endpoint::memory(transport.clone(), "mismatch-test");
            smol::spawn(async move {
                let stream = crate::protocol::ipc::IpcStream::connect(&client_endpoint)
                    .await
                    .unwrap();
                let (mut reader, mut writer) = stream.into_split();
                let client_hs = Handshake::new("v1.0.0");
                framing::send_handshake(&mut writer, &client_hs)
                    .await
                    .unwrap();
                // The server does not reply on mismatch, so this read may fail.
                let _ = framing::recv_handshake(&mut reader).await;
            })
            .detach();
            let result = server.accept_client().await;
            assert!(result.is_err());
            let err = result.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::InvalidData);
            assert!(err.to_string().contains("version mismatch"));
        });
    }

    /// Several clients can be accepted in turn and are all registered.
    #[test]
    fn ipc_server_multiple_clients() {
        use async_channel::unbounded;
        smol::block_on(async {
            let transport = MemoryTransport::new();
            let endpoint = Endpoint::memory(transport.clone(), "multi-client");
            let registry = Arc::new(ClientRegistry::new());
            let mut server = IpcServer::bind(&endpoint, registry.clone(), "v1.0.0")
                .await
                .unwrap();
            // Each client parks on `done_rx` to stay connected until `done_tx` drops.
            let (done_tx, done_rx) = unbounded::<()>();
            let mut accepted = Vec::new();
            for _ in 0..3 {
                let client_endpoint = Endpoint::memory(transport.clone(), "multi-client");
                let rx = done_rx.clone();
                smol::spawn(async move {
                    let stream =
                        crate::protocol::ipc::IpcStream::connect(&client_endpoint)
                            .await
                            .unwrap();
                    let (mut reader, mut writer) = stream.into_split();
                    let client_hs = Handshake::new("v1.0.0");
                    framing::send_handshake(&mut writer, &client_hs)
                        .await
                        .unwrap();
                    let _server_hs = framing::recv_handshake(&mut reader).await.unwrap();
                    let _ = rx.recv().await;
                })
                .detach();
                accepted.push(server.accept_client().await.unwrap());
            }
            assert_eq!(registry.client_count(), 3);
            for id in accepted {
                assert!(registry.get(id).is_some(), "accepted id should be registered");
            }
            drop(done_tx);
        });
    }
}