use crate::{Result, TsgoError};
use corsa_core::fast::CompactString;
use parking_lot::Mutex;
use serde::{Serialize, de::DeserializeOwned};
use serde_json::Value;
use std::{path::Path, sync::Arc};
#[cfg(unix)]
use crate::jsonrpc::JsonRpcConnection;
#[cfg(unix)]
use std::{
io::{BufReader, BufWriter},
path::PathBuf,
};
use super::{
capabilities::{CapabilitiesResponse, RuntimeCapabilities},
changes::{UpdateSnapshotParams, UpdateSnapshotResponse},
config::{ApiMode, ApiSpawnConfig},
document::DocumentIdentifier,
driver::ClientDriver,
encoded::EncodedPayload,
profiling::SharedProfiler,
requests_core::{
ParseConfigFileRequest, ReleaseRequest, SnapshotFileRequest, SnapshotProjectFileRequest,
UpdateSnapshotRequest,
},
responses::{ConfigResponse, InitializeResponse, ProjectResponse},
snapshot::ManagedSnapshot,
spawn_stdio::{spawn_jsonrpc_stdio, spawn_msgpack_stdio},
};
#[derive(Clone)]
pub struct ApiClient {
driver: Arc<ClientDriver>,
initialized: Arc<Mutex<Option<Arc<InitializeResponse>>>>,
capabilities: Arc<Mutex<Option<Arc<CapabilitiesResponse>>>>,
runtime_capabilities: RuntimeCapabilities,
allow_unstable_upstream_calls: bool,
profiler: Option<SharedProfiler>,
}
impl ApiClient {
pub async fn spawn(config: ApiSpawnConfig) -> Result<Self> {
let driver = match config.mode {
ApiMode::AsyncJsonRpcStdio => {
let driver = spawn_jsonrpc_stdio(
&config.command,
config.filesystem.clone(),
config.request_timeout,
config.shutdown_timeout,
config.outbound_capacity,
config.observer.clone(),
)
.await?;
Arc::new(driver)
}
ApiMode::SyncMsgpackStdio => {
let driver = spawn_msgpack_stdio(
&config.command,
config.filesystem.clone(),
config.request_timeout,
config.outbound_capacity,
config.observer.clone(),
)?;
Arc::new(driver)
}
};
Ok(Self {
driver,
initialized: Arc::new(Mutex::new(None)),
capabilities: Arc::new(Mutex::new(None)),
runtime_capabilities: RuntimeCapabilities::from_spawn_config(&config),
allow_unstable_upstream_calls: config.allow_unstable_upstream_calls,
profiler: config.profiler.clone(),
})
}
#[cfg(unix)]
pub async fn connect_pipe(path: impl Into<PathBuf>) -> Result<Self> {
connect_pipe_socket(path.into()).await
}
pub async fn initialize(&self) -> Result<Arc<InitializeResponse>> {
if self.initialized.lock().is_none() {
let init: Arc<InitializeResponse> = Arc::new(
self.driver
.request_typed("initialize", &Value::Null, self.profiler.as_ref())
.await?,
);
let mut slot = self.initialized.lock();
if slot.is_none() {
*slot = Some(init.clone());
}
}
self.initialized
.lock()
.as_ref()
.cloned()
.ok_or(TsgoError::Closed("api initialize"))
}
pub async fn describe_capabilities(&self) -> Result<Arc<CapabilitiesResponse>> {
if self.capabilities.lock().is_none() {
let capabilities = match self
.raw_json_request("describeCapabilities", Value::Null)
.await
{
Ok(value) => {
let mut parsed: CapabilitiesResponse = serde_json::from_value(value)?;
parsed.runtime = parsed
.runtime
.merge_with_local(self.runtime_capabilities.clone());
parsed.runtime.capability_endpoint = true;
Arc::new(parsed)
}
Err(TsgoError::Rpc(error)) if error.code == -32601 => Arc::new(
CapabilitiesResponse::fallback(self.runtime_capabilities.clone()),
),
Err(error) => return Err(error),
};
let mut slot = self.capabilities.lock();
if slot.is_none() {
*slot = Some(capabilities.clone());
}
}
self.capabilities
.lock()
.as_ref()
.cloned()
.ok_or(TsgoError::Closed("api describeCapabilities"))
}
pub async fn parse_config_file(
&self,
file: impl Into<DocumentIdentifier>,
) -> Result<ConfigResponse> {
self.initialize().await?;
let request = ParseConfigFileRequest { file: file.into() };
self.request_after_initialize("parseConfigFile", &request)
.await
}
pub async fn update_snapshot(&self, params: UpdateSnapshotParams) -> Result<ManagedSnapshot> {
if params.overlay_changes.is_some() {
self.require_overlay_update_capability().await?;
}
self.initialize().await?;
let request = UpdateSnapshotRequest {
open_project: params.open_project,
file_changes: params.file_changes,
overlay_changes: params.overlay_changes,
};
let response: UpdateSnapshotResponse = self
.request_after_initialize("updateSnapshot", &request)
.await?;
Ok(super::snapshot::ManagedSnapshot::new(
self.clone(),
response,
))
}
pub async fn get_default_project_for_file(
&self,
snapshot: super::SnapshotHandle,
file: impl Into<DocumentIdentifier>,
) -> Result<Option<ProjectResponse>> {
self.initialize().await?;
let request = SnapshotFileRequest {
snapshot,
file: file.into(),
};
self.request_optional_after_initialize("getDefaultProjectForFile", &request)
.await
}
pub async fn get_source_file(
&self,
snapshot: super::SnapshotHandle,
project: super::ProjectHandle,
file: impl Into<DocumentIdentifier>,
) -> Result<Option<EncodedPayload>> {
self.initialize().await?;
let request = SnapshotProjectFileRequest {
snapshot,
project,
file: file.into(),
};
self.request_binary_after_initialize("getSourceFile", &request)
.await
}
pub async fn close(&self) -> Result<()> {
self.driver.close().await
}
pub fn allows_unstable_upstream_calls(&self) -> bool {
self.allow_unstable_upstream_calls
}
pub async fn raw_json_request(&self, method: &str, params: Value) -> Result<Value> {
self.initialize().await?;
if self.profiler.is_some() {
self.driver
.request_typed(method, ¶ms, self.profiler.as_ref())
.await
} else {
self.driver.request_json(method, params).await
}
}
pub async fn raw_binary_request(
&self,
method: &str,
params: Value,
) -> Result<Option<EncodedPayload>> {
self.initialize().await?;
if self.profiler.is_some() {
Ok(self
.driver
.request_binary_typed(method, ¶ms, self.profiler.as_ref())
.await?
.map(EncodedPayload::new))
} else {
Ok(self
.driver
.request_binary(method, params)
.await?
.map(EncodedPayload::new))
}
}
pub(crate) async fn release_handle(&self, handle: &str) -> Result<()> {
let request = ReleaseRequest { handle };
let _: Value = self.request_after_initialize("release", &request).await?;
Ok(())
}
pub(crate) async fn call<T, P>(&self, method: &str, params: P) -> Result<T>
where
T: DeserializeOwned,
P: Serialize,
{
self.initialize().await?;
self.request_after_initialize(method, ¶ms).await
}
pub(crate) async fn call_optional<T, P>(&self, method: &str, params: P) -> Result<Option<T>>
where
T: DeserializeOwned,
P: Serialize,
{
self.initialize().await?;
self.request_optional_after_initialize(method, ¶ms)
.await
}
pub(crate) async fn call_optional_binary<P>(
&self,
method: &str,
params: P,
) -> Result<Option<EncodedPayload>>
where
P: Serialize,
{
self.initialize().await?;
self.request_binary_after_initialize(method, ¶ms).await
}
pub(crate) async fn require_overlay_update_capability(&self) -> Result<()> {
let capabilities = self.describe_capabilities().await?;
if capabilities.overlay.update_snapshot_overlay_changes {
return Ok(());
}
Err(TsgoError::Unsupported(
"updateSnapshot.overlayChanges is not supported by this runtime; check describeCapabilities before sending in-memory overlays",
))
}
pub(crate) fn map_missing_method(
error: TsgoError,
unsupported_message: &'static str,
) -> TsgoError {
match error {
TsgoError::Rpc(rpc) if rpc.code == -32601 => {
TsgoError::Unsupported(unsupported_message)
}
other => other,
}
}
async fn request_after_initialize<T, P>(&self, method: &str, params: &P) -> Result<T>
where
T: DeserializeOwned,
P: Serialize + ?Sized,
{
self.driver
.request_typed(method, params, self.profiler.as_ref())
.await
}
async fn request_optional_after_initialize<T, P>(
&self,
method: &str,
params: &P,
) -> Result<Option<T>>
where
T: DeserializeOwned,
P: Serialize + ?Sized,
{
let value: Value = self.request_after_initialize(method, params).await?;
if value.is_null() {
Ok(None)
} else {
Ok(Some(serde_json::from_value(value)?))
}
}
async fn request_binary_after_initialize<P>(
&self,
method: &str,
params: &P,
) -> Result<Option<EncodedPayload>>
where
P: Serialize + ?Sized,
{
Ok(self
.driver
.request_binary_typed(method, params, self.profiler.as_ref())
.await?
.map(EncodedPayload::new))
}
}
#[cfg(unix)]
async fn connect_pipe_socket(path: PathBuf) -> Result<ApiClient> {
let stream = std::os::unix::net::UnixStream::connect(path)?;
let reader = BufReader::new(stream.try_clone()?);
let writer = BufWriter::new(stream);
let rpc = JsonRpcConnection::spawn(reader, writer, Default::default());
Ok(ApiClient {
driver: Arc::new(ClientDriver::JsonRpc {
rpc,
process: None,
shutdown_timeout: std::time::Duration::from_secs(2),
}),
initialized: Arc::new(Mutex::new(None)),
capabilities: Arc::new(Mutex::new(None)),
runtime_capabilities: RuntimeCapabilities {
kind: Some(CompactString::from("pipe")),
executable: None,
transport: Some(CompactString::from("jsonrpc")),
capability_endpoint: false,
},
allow_unstable_upstream_calls: false,
profiler: None,
})
}
impl RuntimeCapabilities {
fn from_spawn_config(config: &ApiSpawnConfig) -> Self {
let executable = config.command.executable().to_string_lossy().to_string();
Self {
kind: infer_runtime_kind(config.command.executable()),
executable: Some(CompactString::from(executable)),
transport: Some(match config.mode {
ApiMode::AsyncJsonRpcStdio => CompactString::from("jsonrpc"),
ApiMode::SyncMsgpackStdio => CompactString::from("msgpack"),
}),
capability_endpoint: false,
}
}
}
fn infer_runtime_kind(path: &Path) -> Option<CompactString> {
let normalized = path.to_string_lossy().to_ascii_lowercase();
let kind = if normalized.contains("mock_tsgo") {
"mock-corsa"
} else if normalized.contains("native-preview") {
"native-preview"
} else if normalized.ends_with("/tsgo")
|| normalized.ends_with("\\tsgo.exe")
|| normalized.ends_with("\\tsgo")
|| normalized.ends_with("/tsgo.exe")
{
"tsgo"
} else {
"custom"
};
Some(CompactString::from(kind))
}