ra_ap_proc_macro_api 0.0.324

RPC API for the `proc-macro-srv` crate of rust-analyzer.
Documentation
//! Bidirectional protocol methods

use std::{
    io::{self, BufRead, Write},
    panic::{AssertUnwindSafe, catch_unwind},
    sync::Arc,
};

use paths::AbsPath;
use span::Span;

use crate::{
    ProcMacro, ProcMacroKind, ServerError,
    bidirectional_protocol::msg::{
        BidirectionalMessage, ExpandMacro, ExpandMacroData, ExpnGlobals, Request, Response,
        SubRequest, SubResponse,
    },
    legacy_protocol::{
        SpanMode,
        msg::{
            FlatTree, ServerConfig, SpanDataIndexMap, deserialize_span_data_index_map,
            serialize_span_data_index_map,
        },
    },
    process::ProcMacroServerProcess,
    transport::postcard,
};

pub mod msg;

/// Handler for sub-requests that the proc-macro server sends while a request is in flight.
///
/// In [`run_conversation`], an `Err` from this handler is reported back to the server as a
/// `SubResponse::Cancel` carrying the error's text.
pub type SubCallback<'a> = &'a (dyn Fn(SubRequest) -> Result<SubResponse, ServerError> + 'a);

/// Sends `msg` to the server and services its sub-requests until a response arrives.
///
/// The initial message is postcard-encoded and written to `writer`. Messages are then read
/// from `reader` (with `buf` as the read buffer) in a loop:
/// - a `Response` ends the conversation and is returned;
/// - a `SubRequest` is passed to `callback` and the outcome is written back as a
///   `SubResponse`. If the callback returns `Err` or panics, a `SubResponse::Cancel` is sent
///   instead, so a panic in the callback does not propagate out of this function;
/// - any other message kind is reported as an error.
///
/// # Errors
///
/// Fails on encode/decode or I/O errors, when the stream ends before a response arrives, and
/// when the server sends a message kind other than `Response` or `SubRequest`.
pub fn run_conversation(
    writer: &mut dyn Write,
    reader: &mut dyn BufRead,
    buf: &mut Vec<u8>,
    msg: BidirectionalMessage,
    callback: SubCallback<'_>,
) -> Result<BidirectionalMessage, ServerError> {
    let request_bytes = postcard::encode(&msg).map_err(wrap_encode)?;
    postcard::write(writer, &request_bytes).map_err(wrap_io("failed to write initial request"))?;

    loop {
        let frame = match postcard::read(reader, buf) {
            Ok(Some(frame)) => frame,
            Ok(None) => {
                return Err(ServerError {
                    message: "proc-macro server closed the stream".into(),
                    io: Some(Arc::new(io::Error::new(io::ErrorKind::UnexpectedEof, "closed"))),
                });
            }
            Err(read_error) => return Err(wrap_io("failed to read message")(read_error)),
        };
        let incoming: BidirectionalMessage = postcard::decode(frame).map_err(wrap_decode)?;

        // Only sub-requests keep the conversation going; everything else ends it here.
        let sub_request = match incoming {
            BidirectionalMessage::Response(response) => {
                return Ok(BidirectionalMessage::Response(response));
            }
            BidirectionalMessage::SubRequest(sub_request) => sub_request,
            unexpected => {
                return Err(ServerError {
                    message: format!("unexpected message {unexpected:?}"),
                    io: None,
                });
            }
        };

        // TODO: Avoid `AssertUnwindSafe` by making the callback `UnwindSafe` once `ExpandDatabase`
        // becomes unwind-safe (currently blocked by `parking_lot::RwLock` in the VFS).
        let sub_response = match catch_unwind(AssertUnwindSafe(|| callback(sub_request))) {
            Ok(Ok(answer)) => answer,
            Ok(Err(err)) => SubResponse::Cancel { reason: err.to_string() },
            Err(_) => {
                SubResponse::Cancel { reason: "callback panicked or was cancelled".into() }
            }
        };

        let reply = BidirectionalMessage::SubResponse(sub_response);
        let reply_bytes = postcard::encode(&reply).map_err(wrap_encode)?;
        postcard::write(writer, &reply_bytes).map_err(wrap_io("failed to write sub-response"))?;
    }
}

/// Builds a `map_err` adapter that turns an [`io::Error`] into a [`ServerError`] whose
/// message is the fixed context string `msg`.
fn wrap_io(msg: &'static str) -> impl Fn(io::Error) -> ServerError {
    move |source: io::Error| {
        let io = Some(Arc::new(source));
        ServerError { message: msg.into(), io }
    }
}

fn wrap_encode(err: io::Error) -> ServerError {
    ServerError { message: "failed to encode message".into(), io: Some(Arc::new(err)) }
}

fn wrap_decode(err: io::Error) -> ServerError {
    ServerError { message: "failed to decode message".into(), io: Some(Arc::new(err)) }
}

/// Sends an `ApiVersionCheck` request and returns the version number the server reports.
///
/// # Errors
///
/// Propagates failures from `run_request`. Any reply other than an `ApiVersionCheck`
/// response is reported as a [`ServerError`] that includes the reply's debug representation.
pub(crate) fn version_check(
    srv: &ProcMacroServerProcess,
    callback: SubCallback<'_>,
) -> Result<u32, ServerError> {
    let probe = BidirectionalMessage::Request(Request::ApiVersionCheck {});

    match run_request(srv, probe, callback)? {
        BidirectionalMessage::Response(Response::ApiVersionCheck(version)) => Ok(version),
        unexpected => Err(ServerError {
            message: format!("unexpected response: {unexpected:?}"),
            io: None,
        }),
    }
}

/// Enable support for rust-analyzer span mode if the server supports it.
///
/// Sends a `SetConfig` request asking for [`SpanMode::RustAnalyzer`] and returns the span mode
/// carried in the server's `SetConfig` reply. Sub-requests the server issues meanwhile are
/// answered through `callback`.
///
/// # Errors
///
/// Returns the server's recorded exit error if the process has already exited, and propagates
/// errors from the exchange itself. Any reply other than a `SetConfig` response is reported as a
/// [`ServerError`] that includes the reply's debug representation.
pub(crate) fn enable_rust_analyzer_spans(
    srv: &ProcMacroServerProcess,
    callback: SubCallback<'_>,
) -> Result<SpanMode, ServerError> {
    let request = BidirectionalMessage::Request(Request::SetConfig(ServerConfig {
        span_mode: SpanMode::RustAnalyzer,
    }));

    let response_payload = run_request(srv, request, callback)?;

    match response_payload {
        BidirectionalMessage::Response(Response::SetConfig(ServerConfig { span_mode })) => {
            Ok(span_mode)
        }
        // Include what actually arrived, as `version_check` does, so that protocol mismatches
        // can be diagnosed from the error alone.
        other => {
            Err(ServerError { message: format!("unexpected response: {other:?}"), io: None })
        }
    }
}

/// Finds proc-macros in a given dynamic library.
///
/// Sends a `ListMacros` request for `dylib_path`. Sub-requests the server issues meanwhile are
/// answered through `callback`.
///
/// The outer `Result` carries protocol/transport failures. The inner `Result` is the server's
/// answer: the `(String, ProcMacroKind)` entries it reported (presumably macro name and kind),
/// or the error string it returned.
///
/// # Errors
///
/// Returns the server's recorded exit error if the process has already exited, and propagates
/// errors from the exchange itself. Any reply other than a `ListMacros` response is reported as
/// a [`ServerError`] that includes the reply's debug representation.
pub(crate) fn find_proc_macros(
    srv: &ProcMacroServerProcess,
    dylib_path: &AbsPath,
    callback: SubCallback<'_>,
) -> Result<Result<Vec<(String, ProcMacroKind)>, String>, ServerError> {
    let request = BidirectionalMessage::Request(Request::ListMacros {
        dylib_path: dylib_path.to_path_buf().into(),
    });

    let response_payload = run_request(srv, request, callback)?;

    match response_payload {
        BidirectionalMessage::Response(Response::ListMacros(it)) => Ok(it),
        // Include what actually arrived, as `version_check` does, so that protocol mismatches
        // can be diagnosed from the error alone.
        other => {
            Err(ServerError { message: format!("unexpected response: {other:?}"), io: None })
        }
    }
}

/// Asks the server to expand `proc_macro` on `subtree` (plus `attr` for attribute macros).
///
/// The input trees are flattened against a span table seeded with `def_site`, `call_site` and
/// `mixed_site`. Their table indices are sent as the expansion globals. The serialized table is
/// only included when the process has rust-analyzer span mode enabled; otherwise an empty table
/// is sent. `env` and `current_dir` are forwarded with the request, and server sub-requests are
/// answered through `callback`.
///
/// The outer `Result` carries protocol/transport failures. The inner `Result` is the expansion
/// outcome, with `Err` holding the message the server returned.
///
/// # Errors
///
/// Returns the server's recorded exit error if the process has already exited, and propagates
/// errors from the exchange itself. Any reply other than an `ExpandMacro` or
/// `ExpandMacroExtended` response is reported as a [`ServerError`] that includes the reply's
/// debug representation.
pub(crate) fn expand(
    proc_macro: &ProcMacro,
    process: &ProcMacroServerProcess,
    subtree: tt::SubtreeView<'_>,
    attr: Option<tt::SubtreeView<'_>>,
    env: Vec<(String, String)>,
    def_site: Span,
    call_site: Span,
    mixed_site: Span,
    current_dir: String,
    callback: SubCallback<'_>,
) -> Result<Result<tt::TopSubtree, String>, crate::ServerError> {
    let version = process.version();
    let mut span_data_table = SpanDataIndexMap::default();
    // The three global spans are referred to by their index in the table.
    let def_site = span_data_table.insert_full(def_site).0;
    let call_site = span_data_table.insert_full(call_site).0;
    let mixed_site = span_data_table.insert_full(mixed_site).0;
    let task = BidirectionalMessage::Request(Request::ExpandMacro(Box::new(ExpandMacro {
        data: ExpandMacroData {
            macro_body: FlatTree::from_subtree(subtree, version, &mut span_data_table),
            macro_name: proc_macro.name.to_string(),
            attributes: attr
                .map(|subtree| FlatTree::from_subtree(subtree, version, &mut span_data_table)),
            has_global_spans: ExpnGlobals { def_site, call_site, mixed_site },
            span_data_table: if process.rust_analyzer_spans() {
                serialize_span_data_index_map(&span_data_table)
            } else {
                Vec::new()
            },
        },
        lib: proc_macro.dylib_path.to_path_buf().into(),
        env,
        current_dir: Some(current_dir),
    })));

    let response_payload = run_request(process, task, callback)?;

    // Post-processing shared by both response flavours: adjust fixup markers when this
    // proc-macro requires it (see `ProcMacro::needs_fixup_change`).
    let finish = |mut expanded: tt::TopSubtree| {
        if proc_macro.needs_fixup_change() {
            proc_macro.change_fixup_to_match_old_server(&mut expanded);
        }
        expanded
    };

    match response_payload {
        BidirectionalMessage::Response(Response::ExpandMacro(it)) => Ok(it
            // Plain responses are resolved against the span table built above.
            .map(|tree| finish(FlatTree::to_subtree_resolved(tree, version, &span_data_table)))
            .map_err(|msg| msg.0)),
        BidirectionalMessage::Response(Response::ExpandMacroExtended(it)) => Ok(it
            // Extended responses carry their own span table, which is used instead.
            .map(|resp| {
                finish(FlatTree::to_subtree_resolved(
                    resp.tree,
                    version,
                    &deserialize_span_data_index_map(&resp.span_data_table),
                ))
            })
            .map_err(|msg| msg.0)),
        // Include what actually arrived, as `version_check` does, so that protocol mismatches
        // can be diagnosed from the error alone.
        other => {
            Err(ServerError { message: format!("unexpected response: {other:?}"), io: None })
        }
    }
}

/// Runs one bidirectional exchange with the server, unless it has already exited.
///
/// If `srv.exited()` reports an error, a clone of that error is returned without sending
/// anything. Otherwise the call is delegated to `run_bidirectional`.
fn run_request(
    srv: &ProcMacroServerProcess,
    msg: BidirectionalMessage,
    callback: SubCallback<'_>,
) -> Result<BidirectionalMessage, ServerError> {
    match srv.exited() {
        Some(exit_error) => Err(exit_error.clone()),
        None => srv.run_bidirectional(msg, callback),
    }
}

/// A sub-request handler, usable as a [`SubCallback`], that refuses every request.
///
/// Always returns a [`ServerError`] naming the rejected request.
pub fn reject_subrequests(req: SubRequest) -> Result<SubResponse, ServerError> {
    let message = format!("{req:?} sub-request not supported here");
    Err(ServerError { message, io: None })
}