Skip to main content

oak_rust/
lsp.rs

1use crate::{RustLanguage, parser::RustElementType};
2use core::range::Range;
3use dashmap::DashMap;
4use futures::{Future, FutureExt};
5use oak_core::{
6    GreenNode, Source,
7    language::{ElementType, TokenType},
8    parser::{ParseCache, Parser, session::ParseSession},
9    tree::RedNode,
10};
11use oak_hover::{Hover, HoverProvider};
12use oak_lsp::service::LanguageService;
13use oak_vfs::Vfs;
14
15/// Hover provider implementation for Rust.
16pub struct RustHoverProvider;
17
18impl HoverProvider<RustLanguage> for RustHoverProvider {
19    fn hover(&self, node: &RedNode<RustLanguage>, _range: Range<usize>) -> Option<Hover> {
20        let kind = node.green.kind;
21
22        // Provide context-aware hover information
23        let contents = match kind {
24            RustElementType::Function => "### Rust Function\nDefines a callable block of code.",
25            RustElementType::StructItem => "### Rust Struct\nDefines a custom data type.",
26            RustElementType::ModuleItem => "### Rust Module\nOrganizes code into namespaces.",
27            RustElementType::Trait => "### Rust Trait\nDefines a shared behavior.",
28            _ => return None,
29        };
30
31        Some(Hover { contents: contents.to_string(), range: Some(node.span()) })
32    }
33}
34
35/// Language service implementation for Rust.
36pub struct RustLanguageService<V: Vfs> {
37    vfs: V,
38    workspace: oak_lsp::workspace::WorkspaceManager,
39    hover_provider: RustHoverProvider,
40    sessions: DashMap<String, Box<ParseSession<RustLanguage>>>,
41}
42
43impl<V: Vfs> RustLanguageService<V> {
44    /// Creates a new `RustLanguageService`.
45    pub fn new(vfs: V) -> Self {
46        Self { vfs, workspace: oak_lsp::workspace::WorkspaceManager::default(), hover_provider: RustHoverProvider, sessions: DashMap::new() }
47    }
48
49    fn collect_definitions<S: Source + ?Sized>(&self, node: &RedNode<RustLanguage>, name: &str, source: &S, uri: &str, definitions: &mut Vec<oak_lsp::LocationRange>) {
50        use oak_core::{
51            language::{ElementRole, UniversalElementRole},
52            tree::RedTree,
53        };
54
55        let role = node.green.kind.role();
56        if role.universal() == UniversalElementRole::Definition {
57            for child in node.children() {
58                if let RedTree::Leaf(leaf) = child {
59                    if leaf.kind.is_universal(oak_core::language::UniversalTokenRole::Name) {
60                        let leaf_name = source.get_text_in(leaf.span.clone());
61                        if leaf_name.as_ref() == name {
62                            definitions.push(oak_lsp::LocationRange { uri: uri.to_string().into(), range: leaf.span });
63                            return;
64                        }
65                    }
66                }
67            }
68        }
69
70        for child in node.children() {
71            if let RedTree::Node(child_node) = child {
72                self.collect_definitions(&child_node, name, source, uri, definitions);
73            }
74        }
75    }
76}
77
78impl<V: Vfs + Send + Sync + 'static + oak_vfs::WritableVfs> LanguageService for RustLanguageService<V> {
79    type Lang = RustLanguage;
80    type Vfs = V;
81
82    fn vfs(&self) -> &Self::Vfs {
83        &self.vfs
84    }
85
86    fn workspace(&self) -> &oak_lsp::workspace::WorkspaceManager {
87        &self.workspace
88    }
89
90    fn get_root(&self, uri: &str) -> impl Future<Output = Option<RedNode<'_, RustLanguage>>> + Send + '_ {
91        let uri = uri.to_string();
92        async move {
93            let source = self.vfs().get_source(&uri)?;
94            let mut session = self.sessions.entry(uri.clone()).or_insert_with(|| Box::new(ParseSession::<RustLanguage>::default()));
95
96            let language = RustLanguage::default();
97            let parser = crate::parser::RustParser::new(&language);
98
99            let session_guard = session.value_mut();
100            let session_ptr: *mut ParseSession<RustLanguage> = session_guard.as_mut();
101
102            let output = parser.parse(&source, &[], session_guard.as_mut());
103
104            // Commit the generation so it can be retrieved via last_root
105            let root_green = output.result.ok()?;
106            unsafe {
107                (*session_ptr).commit_generation(root_green);
108            }
109
110            // Safety: The root is stored in the ParseSession, which is inside a Box in the DashMap.
111            // Since the DashMap is owned by self, and the Box provides stable addressing,
112            // we can safely extend the lifetime to the lifetime of self.
113            let root_ptr = unsafe {
114                let ptr = (*session_ptr).old_tree().or_else(|| Some(root_green))?;
115                std::mem::transmute::<&GreenNode<RustLanguage>, &GreenNode<RustLanguage>>(ptr)
116            };
117
118            Some(RedNode::new(root_ptr, 0))
119        }
120    }
121
122    fn definition<'a>(&'a self, uri: &'a str, range: Range<usize>) -> impl Future<Output = Vec<oak_lsp::LocationRange>> + Send + 'a {
123        let uri = uri.to_string();
124        async move {
125            let root = self.get_root(&uri).await?;
126            let source = self.vfs().get_source(&uri)?;
127            let leaf = root.leaf_at_offset(range.start)?;
128
129            if !leaf.kind.is_universal(oak_core::language::UniversalTokenRole::Name) {
130                return None;
131            }
132
133            use oak_core::Source;
134            let name = source.get_text_in(leaf.span.clone());
135            let name = name.as_ref();
136
137            // Search for definitions in all files in the workspace
138            let mut all_definitions = Vec::new();
139            let files = self.list_all_files(&uri).await;
140
141            for file_uri in files {
142                if let Some(file_root) = self.get_root(&file_uri).await {
143                    if let Some(file_source) = self.vfs().get_source(&file_uri) {
144                        self.collect_definitions(&file_root, name, &file_source, &file_uri, &mut all_definitions);
145                    }
146                }
147            }
148
149            Some(all_definitions)
150        }
151        .then(|opt| async { opt.unwrap_or_default() })
152    }
153
154    fn references<'a>(&'a self, uri: &'a str, range: Range<usize>) -> impl Future<Output = Vec<oak_lsp::LocationRange>> + Send + 'a {
155        let uri = uri.to_string();
156        async move {
157            let root = self.get_root(&uri).await?;
158            let source = self.vfs().get_source(&uri)?;
159            let leaf = root.leaf_at_offset(range.start)?;
160
161            if !leaf.kind.is_universal(oak_core::language::UniversalTokenRole::Name) {
162                return None;
163            }
164
165            use oak_core::Source;
166            let name = source.get_text_in(leaf.span.clone());
167            let name = name.to_string();
168
169            // Search in all files in the workspace
170            // Note: In a real LSP, we would use an index for performance
171            let mut all_refs = Vec::new();
172
173            // Get all workspace folders or just list files from root
174            // For this example, we'll just search the current file and any other files we can find
175            let files = self.list_all_files(&uri).await; // This is a bit hacky as it uses current file as root
176
177            for file_uri in files {
178                if let Some(file_root) = self.get_root(&file_uri).await {
179                    if let Some(file_source) = self.vfs().get_source(&file_uri) {
180                        let source_ref: &dyn oak_core::Source = &file_source;
181                        // For SimpleReferenceFinder::find, we need the text source content as a string slice
182                        // But SimpleReferenceFinder might expect different arguments based on the error
183                        // "expected `&str`, found `&<V as Vfs>::Source`"
184                        // This implies SimpleReferenceFinder::find expects `&str` for source.
185                        // We need to get the full text from the source.
186                        let source_len = source_ref.length();
187                        let full_text = source_ref.get_text_in(oak_core::Range { start: 0, end: source_len });
188                        let full_text_str = full_text.as_ref();
189
190                        let refs = oak_navigation::SimpleReferenceFinder::find(&file_root, &name, full_text_str, file_uri.clone());
191                        all_refs.extend(refs.into_iter().map(|l| oak_lsp::LocationRange { uri: l.uri, range: l.range }));
192                    }
193                }
194            }
195
196            Some(all_refs)
197        }
198        .then(|opt| async { opt.unwrap_or_default() })
199    }
200
201    fn rename<'a>(&'a self, uri: &'a str, range: Range<usize>, new_name: String) -> impl Future<Output = Option<oak_lsp::WorkspaceEdit>> + Send + 'a {
202        let uri = uri.to_string();
203        async move {
204            let refs = self.references(&uri, range).await;
205            if refs.is_empty() {
206                return None;
207            }
208
209            let mut changes = std::collections::HashMap::new();
210            for r in refs {
211                changes.entry(r.uri.to_string()).or_insert_with(Vec::new).push(oak_lsp::TextEdit { range: r.range, new_text: new_name.clone() });
212            }
213
214            Some(oak_lsp::WorkspaceEdit { changes })
215        }
216    }
217
218    fn hover(&self, uri: &str, range: Range<usize>) -> impl Future<Output = Option<oak_lsp::Hover>> + Send + '_ {
219        let uri = uri.to_string();
220        async move {
221            self.with_root(&uri, |root| {
222                // In a real implementation, you would find the specific node at offset
223                // For this example, we just check the root or simple children
224                self.hover_provider.hover(&root, range).map(|h| oak_lsp::Hover { contents: h.contents, range: h.range })
225            })
226            .await
227            .flatten()
228        }
229    }
230}