1use std::collections::HashMap;
20use std::num::NonZeroUsize;
21
22use super::task_pool::TaskPool;
23use super::InitOptions;
24use super::Session;
25use auto_lsp_core::errors::{ExtensionError, RuntimeError};
26use lsp_server::{Connection, ReqQueue};
27use lsp_types::{InitializeParams, InitializeResult, PositionEncodingKind};
28use serde::Deserialize;
29#[cfg(target_arch = "wasm32")]
30use std::fs;
31use texter::core::text::Text;
32
/// Shape of the client-supplied `initializationOptions` JSON payload.
///
/// Field names mirror the JSON keys verbatim, hence the `non_snake_case` allowance.
#[allow(non_snake_case, reason = "JSON")]
#[derive(Debug, Deserialize)]
struct InitializationOptions {
    /// Maps a file extension to the name of the parser to use for it.
    /// `Session::create` checks every value against `init_options.parsers`.
    perFileParser: HashMap<String, String>,
}
42
/// Function pointer that builds a [`Text`] from an owned `String`.
///
/// `decide_encoding` selects either `Text::new_utf16` or `Text::new` as the value.
pub(crate) type TextFn = fn(String) -> Text;
45
46fn decide_encoding(encs: Option<&[PositionEncodingKind]>) -> (TextFn, PositionEncodingKind) {
47 const DEFAULT: (TextFn, PositionEncodingKind) = (Text::new_utf16, PositionEncodingKind::UTF16);
48 let Some(encs) = encs else {
49 return DEFAULT;
50 };
51
52 for enc in encs {
53 if *enc == PositionEncodingKind::UTF16 {
54 return (Text::new_utf16, enc.clone());
55 } else if *enc == PositionEncodingKind::UTF8 {
56 return (Text::new, enc.clone());
57 }
58 }
59
60 DEFAULT
61}
62
63impl<Db: salsa::Database> Session<Db> {
64 pub(crate) fn new(
65 init_options: InitOptions,
66 connection: Connection,
67 text_fn: TextFn,
68 db: Db,
69 ) -> Self {
70 let (sender, task_rx) = crossbeam_channel::unbounded();
71
72 let max_threads = std::thread::available_parallelism()
73 .unwrap_or_else(|_| NonZeroUsize::new(1).unwrap())
74 .get();
75
76 log::info!("Max threads: {max_threads}");
77
78 Self {
79 init_options,
80 connection,
81 text_fn,
82 extensions: HashMap::new(),
83 req_queue: ReqQueue::default(),
84 db,
85 task_rx,
86 task_pool: TaskPool::new_with_threads(sender, max_threads),
87 }
88 }
89
90 pub fn create(
94 mut init_options: InitOptions,
95 connection: Connection,
96 db: Db,
97 ) -> anyhow::Result<(Session<Db>, InitializeParams)> {
98 #[cfg(target_arch = "wasm32")]
101 fs::metadata("/workspace").unwrap();
102
103 log::info!("Starting LSP server");
104 log::info!("");
105
106 let (id, resp) = connection.initialize_start()?;
109 let params: InitializeParams = serde_json::from_value(resp)?;
110
111 let pos_encoding = params
112 .capabilities
113 .general
114 .as_ref()
115 .and_then(|g| g.position_encodings.as_deref());
116
117 let (t_fn, enc) = decide_encoding(pos_encoding);
118 init_options.capabilities.position_encoding = Some(enc);
119
120 let server_capabilities = serde_json::to_value(&InitializeResult {
121 capabilities: init_options.capabilities.clone(),
122 server_info: init_options.server_info.clone(),
123 })
124 .unwrap();
125
126 connection.initialize_finish(id, server_capabilities)?;
127
128 let mut session = Session::new(init_options, connection, t_fn, db);
129
130 let options = InitializationOptions::deserialize(
131 params
132 .clone()
133 .initialization_options
134 .ok_or(RuntimeError::MissingPerFileParser)?,
135 )
136 .unwrap();
137
138 for (file_extension, parser) in &options.perFileParser {
140 if !session.init_options.parsers.contains_key(parser.as_str()) {
141 return Err(RuntimeError::from(ExtensionError::UnknownParser {
142 extension: file_extension.clone(),
143 available: session.init_options.parsers.keys().cloned().collect(),
144 })
145 .into());
146 }
147 }
148
149 session.extensions = options.perFileParser;
151
152 Ok((session, params))
153 }
154}