1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
//! Language Server lifecycle.
//!
//! *Only applies to Language Servers.*
//!
//! This middleware handles
//! [the lifecycle of Language Servers](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#lifeCycleMessages),
//! specifically:
//! - Exit the main loop with `ControlFlow::Break(Ok(()))` on the `exit` notification.
//! - Respond to unrelated requests with errors, and ignore unrelated notifications, during
//!   initialization and shutdown.
use std::future::{ready, Future, Ready};
use std::ops::ControlFlow;
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::future::Either;
use lsp_types::notification::{self, Notification};
use lsp_types::request::{self, Request};
use pin_project_lite::pin_project;
use tower_layer::Layer;
use tower_service::Service;

use crate::{
    AnyEvent, AnyNotification, AnyRequest, Error, ErrorCode, LspService, ResponseError, Result,
};

/// The lifecycle phase the wrapped server is currently in.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
enum State {
    /// No `initialize` request has been accepted yet; this is the starting phase.
    #[default]
    Uninitialized,
    /// `initialize` was forwarded; waiting for the `initialized` notification.
    Initializing,
    /// Fully initialized; requests are forwarded to the inner service.
    Ready,
    /// `shutdown` was forwarded; further requests are rejected.
    ShuttingDown,
}

/// The middleware handling the Language Server lifecycle.
///
/// It wraps an inner service, tracks which lifecycle phase the server is in, and rejects
/// requests that are not valid for that phase.
///
/// See the [module level documentation](self) for details.
#[derive(Default, Debug)]
pub struct Lifecycle<S> {
    /// The wrapped service that receives every request and notification this middleware
    /// lets through.
    service: S,
    /// Current lifecycle phase; starts out uninitialized.
    state: State,
}

// Presumably generates accessor methods for the wrapped `service` field. The exact API is
// defined by the `define_getters!` macro elsewhere in the crate — confirm there.
define_getters!(impl[S] Lifecycle<S>, service: S);

impl<S> Lifecycle<S> {
    /// Wraps `service` in a `Lifecycle` middleware that has not yet seen an `initialize`
    /// request, i.e. it starts in the uninitialized state.
    #[must_use]
    pub fn new(service: S) -> Self {
        let state = State::Uninitialized;
        Self { service, state }
    }
}

impl<S: LspService> Service<AnyRequest> for Lifecycle<S>
where
    S::Error: From<ResponseError>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = ResponseFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Readiness is decided entirely by the wrapped service.
        self.service.poll_ready(cx)
    }

    /// Gates an incoming request on the current lifecycle phase.
    ///
    /// - Uninitialized: only `initialize` is forwarded, which moves the phase to initializing.
    ///   Anything else fails with `SERVER_NOT_INITIALIZED`.
    /// - Initializing: every request fails with `SERVER_NOT_INITIALIZED`.
    /// - Ready: a repeated `initialize` fails with `INVALID_REQUEST`. Everything else is
    ///   forwarded, and `shutdown` additionally moves the phase to shutting down.
    /// - Shutting down: every request fails with `INVALID_REQUEST`.
    fn call(&mut self, req: AnyRequest) -> Self::Future {
        let is_initialize = req.method == request::Initialize::METHOD;
        let current = self.state;

        // Decide whether this request may reach the inner service, applying any state
        // transition it triggers. A rejection carries the error code and message to report.
        let verdict = match current {
            State::Uninitialized if is_initialize => {
                self.state = State::Initializing;
                Ok(())
            }
            State::Uninitialized | State::Initializing => Err((
                ErrorCode::SERVER_NOT_INITIALIZED,
                "Server is not initialized yet",
            )),
            State::Ready | State::ShuttingDown if is_initialize => {
                Err((ErrorCode::INVALID_REQUEST, "Server is already initialized"))
            }
            State::Ready => {
                if req.method == request::Shutdown::METHOD {
                    self.state = State::ShuttingDown;
                }
                Ok(())
            }
            State::ShuttingDown => Err((ErrorCode::INVALID_REQUEST, "Server is shutting down")),
        };

        let inner = match verdict {
            Ok(()) => Either::Left(self.service.call(req)),
            // Rejected requests never reach the inner service; they resolve immediately.
            Err((code, message)) => Either::Right(ready(Err(ResponseError {
                code,
                message: message.into(),
                data: None,
            }
            .into()))),
        };
        ResponseFuture { inner }
    }
}

impl<S: LspService> LspService for Lifecycle<S>
where
    S::Error: From<ResponseError>,
{
    /// Applies the lifecycle rules to an incoming notification.
    ///
    /// - `exit` is forwarded in every phase. The main loop is then stopped with
    ///   `ControlFlow::Break(Ok(()))`.
    /// - `initialized` is only valid while initializing. There it moves the phase to ready and
    ///   is forwarded. In any other phase it breaks the main loop with [`Error::Protocol`].
    /// - Any other notification is forwarded only once the server is ready. Before
    ///   initialization completes, and after `shutdown` was accepted, it is dropped. This
    ///   matches the module documentation and the LSP lifecycle rules.
    fn notify(&mut self, notif: AnyNotification) -> ControlFlow<Result<()>> {
        match &*notif.method {
            notification::Exit::METHOD => {
                self.service.notify(notif)?;
                ControlFlow::Break(Ok(()))
            }
            notification::Initialized::METHOD => {
                if self.state != State::Initializing {
                    return ControlFlow::Break(Err(Error::Protocol(format!(
                        "Unexpected initialized notification on state {:?}",
                        self.state
                    ))));
                }
                self.state = State::Ready;
                self.service.notify(notif)?;
                ControlFlow::Continue(())
            }
            _ if self.state == State::Ready => self.service.notify(notif),
            // Outside the ready phase, clients must not send notifications other than the
            // lifecycle ones handled above, so unrelated ones are dropped instead of
            // being forwarded.
            _ => ControlFlow::Continue(()),
        }
    }

    /// Forwards events to the inner service unchanged; the lifecycle does not gate them.
    fn emit(&mut self, event: AnyEvent) -> ControlFlow<Result<()>> {
        self.service.emit(event)
    }
}

pin_project! {
    /// The [`Future`] type used by the [`Lifecycle`] middleware.
    ///
    /// Resolves either to the inner service's response future (`Either::Left`) or to an
    /// error produced immediately by the lifecycle checks (`Either::Right`).
    pub struct ResponseFuture<Fut: Future> {
        // Structurally pinned so an arbitrary (possibly `!Unpin`) inner future can be
        // polled in place.
        #[pin]
        inner: Either<Fut, Ready<Fut::Output>>,
    }
}

impl<Fut: Future> Future for ResponseFuture<Fut> {
    type Output = Fut::Output;

    /// Polls whichever variant this future wraps: the inner service's future or the
    /// already-completed rejection.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        projected.inner.poll(cx)
    }
}

/// A [`tower_layer::Layer`] which builds [`Lifecycle`].
///
/// Every service it wraps starts in the uninitialized state.
#[must_use]
#[derive(Debug, Clone, Default)]
pub struct LifecycleLayer {
    // The private field forces construction through `Default`, so fields can be added
    // later without breaking users.
    _private: (),
}

impl<S> Layer<S> for LifecycleLayer {
    type Service = Lifecycle<S>;

    /// Wraps `inner` in a new [`Lifecycle`] middleware that starts uninitialized.
    fn layer(&self, inner: S) -> Self::Service {
        Lifecycle::<S>::new(inner)
    }
}