use std::collections::HashMap;
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::{Arc, Mutex};
use std::time::UNIX_EPOCH;
use futures_util::{Sink, SinkExt, StreamExt};
use rmpv::Value;
use rpc_runtime_core::{InstanceId, MethodId, ServiceGuid};
use rpc_runtime_errors::{RuntimeError, RuntimeErrorCode};
use rpc_runtime_server::{
ConnectionCleanupFuture, HandlerFuture, RpcCallContext, RpcConnectionCleanupSink, RpcServer,
RpcServerBuilder, RpcServerSecurityConfig,
};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt, SeekFrom};
use tokio::net::tcp::OwnedWriteHalf;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Mutex as AsyncMutex;
use tokio::task::JoinHandle;
use tokio_tungstenite::tungstenite::Message;
use uuid::Uuid;
mod archive;
mod filesystem;
pub mod generated;
// Instance names under which the built-in native services are registered.
pub const RUNTIME_INSTANCE: &str = "tripley.native.runtime";
pub const FS_INSTANCE: &str = "tripley.native.fs";
pub const ARCHIVE_INSTANCE: &str = "tripley.native.archive";
pub const TCP_INSTANCE: &str = "tripley.native.tcp";
pub const WEBSOCKET_INSTANCE: &str = "tripley.native.websocket";
pub const SQLITE_INSTANCE: &str = "tripley.native.sqlite";
pub const SYSTEM_INSTANCE: &str = "tripley.native.system";
// Service GUIDs paired with the instance names above; each must parse as a UUID because
// registration parses it with `Uuid::parse_str(..).expect(..)`.
pub const RUNTIME_SERVICE_GUID: &str = "52c6943d-d956-4a42-b69e-51e94399c001";
pub const FS_SERVICE_GUID: &str = "52c6943d-d956-4a42-b69e-51e94399c002";
pub const ARCHIVE_SERVICE_GUID: &str = "1e7d1d50-721c-4b14-ad7e-7793170cea05";
pub const TCP_SERVICE_GUID: &str = "52c6943d-d956-4a42-b69e-51e94399c003";
pub const WEBSOCKET_SERVICE_GUID: &str = "52c6943d-d956-4a42-b69e-51e94399c004";
pub const SQLITE_SERVICE_GUID: &str = "52c6943d-d956-4a42-b69e-51e94399c005";
pub const SYSTEM_SERVICE_GUID: &str = "52c6943d-d956-4a42-b69e-51e94399c006";
/// Notification id used for every TCP and WebSocket event pushed to the client.
pub const EVENT_NOTIFICATION_ID: u32 = 1;
/// Extension point for adding extra RPC services to the native server.
///
/// Providers are supplied through `NativeRpcServerOptions::providers`. While the server is
/// built, each provider's `register` is invoked with the shared builder and its
/// `capabilities` are appended to the capability list reported by the runtime service.
pub trait NativeRpcProvider: Send + Sync {
/// Registers this provider's services on `builder`.
fn register(&self, builder: &mut RpcServerBuilder);
/// Returns the capability names contributed by this provider.
fn capabilities(&self) -> Vec<&'static str>;
}
/// Flags selecting which optional built-in services are enabled.
///
/// The `runtime.info` capability is always reported; every flag adds the capability
/// names of one optional service.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NativeServiceSet {
    /// Filesystem service.
    pub fs: bool,
    /// Archive service.
    pub archive: bool,
    /// TCP client/server service.
    pub tcp: bool,
    /// WebSocket client/server service.
    pub websocket: bool,
    /// SQLite service.
    pub sqlite: bool,
    /// System power (shutdown/reboot) service.
    pub system: bool,
}

impl NativeServiceSet {
    /// Every optional service enabled.
    pub const fn all() -> Self {
        Self::uniform(true)
    }

    /// Every optional service disabled; only the runtime service remains.
    pub const fn runtime_only() -> Self {
        Self::uniform(false)
    }

    /// Builds a set in which every flag carries the same value.
    const fn uniform(enabled: bool) -> Self {
        Self {
            fs: enabled,
            archive: enabled,
            tcp: enabled,
            websocket: enabled,
            sqlite: enabled,
            system: enabled,
        }
    }

    /// Lists the capability names for the enabled services.
    ///
    /// The list always starts with `"runtime.info"`, followed by the names of each enabled
    /// service in the fixed order fs, archive, tcp, websocket, sqlite, system.
    pub fn capabilities(self) -> Vec<&'static str> {
        let groups: [(bool, &'static [&'static str]); 6] = [
            (self.fs, &["fs"]),
            (self.archive, &["archive"]),
            (self.tcp, &["tcp.client", "tcp.server"]),
            (self.websocket, &["websocket.client", "websocket.server"]),
            (self.sqlite, &["sqlite"]),
            (self.system, &["system.shutdown", "system.reboot"]),
        ];
        let optional = groups
            .iter()
            .filter(|group| group.0)
            .flat_map(|group| group.1.iter().copied());
        std::iter::once("runtime.info").chain(optional).collect()
    }
}

impl Default for NativeServiceSet {
    /// Defaults to every service enabled.
    fn default() -> Self {
        Self::all()
    }
}
/// Options consumed by `build_native_rpc_server_with_options`.
#[derive(Clone)]
pub struct NativeRpcServerOptions {
/// Policy consulted before filesystem, network, SQLite and power operations.
pub policy: Arc<dyn NativePolicy>,
/// Which optional built-in services get registered.
pub services: NativeServiceSet,
/// Security configuration handed to the server builder.
pub security: RpcServerSecurityConfig,
/// Additional providers registered after the built-in services.
pub providers: Vec<Arc<dyn NativeRpcProvider>>,
}
impl Default for NativeRpcServerOptions {
fn default() -> Self {
Self {
policy: Arc::new(DevPermissivePolicy),
services: NativeServiceSet::all(),
security: RpcServerSecurityConfig::default(),
providers: Vec::new(),
}
}
}
/// Shared state behind every built-in native service handler.
///
/// Cloning is cheap: the resource tables are `Arc`-wrapped, so all clones observe the same
/// open resources. Each table maps a resource id to its entry.
#[derive(Clone)]
pub struct NativeState {
// Open file handles.
open_files: Arc<AsyncMutex<HashMap<String, OpenFileResource>>>,
// Write halves of connected/accepted TCP sockets.
tcp_sockets: Arc<AsyncMutex<HashMap<String, TcpSocketResource>>>,
// Accept-loop tasks of TCP listeners.
tcp_servers: Arc<AsyncMutex<HashMap<String, TaskResource>>>,
// Sink halves of connected/accepted WebSockets.
websocket_sockets: Arc<AsyncMutex<HashMap<String, WebSocketResource>>>,
// Accept-loop tasks of WebSocket listeners.
websocket_servers: Arc<AsyncMutex<HashMap<String, TaskResource>>>,
// Open SQLite connections; guarded by a blocking std mutex rather than an async one.
sqlite: Arc<Mutex<HashMap<String, SqliteResource>>>,
// Policy consulted before privileged operations.
policy: Arc<dyn NativePolicy>,
// Enabled built-in services (drives the reported capabilities).
services: NativeServiceSet,
// Capabilities contributed by external providers.
provider_capabilities: Arc<Vec<&'static str>>,
}
impl Default for NativeState {
fn default() -> Self {
Self {
open_files: Arc::new(AsyncMutex::new(HashMap::new())),
tcp_sockets: Arc::new(AsyncMutex::new(HashMap::new())),
tcp_servers: Arc::new(AsyncMutex::new(HashMap::new())),
websocket_sockets: Arc::new(AsyncMutex::new(HashMap::new())),
websocket_servers: Arc::new(AsyncMutex::new(HashMap::new())),
sqlite: Arc::new(Mutex::new(HashMap::new())),
policy: Arc::new(DevPermissivePolicy),
services: NativeServiceSet::all(),
provider_capabilities: Arc::new(Vec::new()),
}
}
}
impl NativeState {
async fn dispose_connection_resources(&self, connection_id: u64) {
let file_ids = {
let files = self.open_files.lock().await;
files
.iter()
.filter_map(|(id, resource)| {
(resource.owner_connection_id == connection_id).then(|| id.clone())
})
.collect::<Vec<_>>()
};
for id in file_ids {
self.open_files.lock().await.remove(&id);
}
let tcp_socket_ids = {
let sockets = self.tcp_sockets.lock().await;
sockets
.iter()
.filter_map(|(id, resource)| {
(resource.owner_connection_id == connection_id).then(|| id.clone())
})
.collect::<Vec<_>>()
};
for id in tcp_socket_ids {
self.tcp_sockets.lock().await.remove(&id);
}
let tcp_server_ids = {
let servers = self.tcp_servers.lock().await;
servers
.iter()
.filter_map(|(id, resource)| {
(resource.owner_connection_id == connection_id).then(|| id.clone())
})
.collect::<Vec<_>>()
};
for id in tcp_server_ids {
if let Some(resource) = self.tcp_servers.lock().await.remove(&id) {
resource.task.abort();
}
}
let websocket_ids = {
let sockets = self.websocket_sockets.lock().await;
sockets
.iter()
.filter_map(|(id, resource)| {
(resource.owner_connection_id == connection_id).then(|| id.clone())
})
.collect::<Vec<_>>()
};
for id in websocket_ids {
if let Some(resource) = self.websocket_sockets.lock().await.remove(&id) {
let _ = resource
.writer
.lock()
.await
.send(Message::Close(None))
.await;
}
}
let websocket_server_ids = {
let servers = self.websocket_servers.lock().await;
servers
.iter()
.filter_map(|(id, resource)| {
(resource.owner_connection_id == connection_id).then(|| id.clone())
})
.collect::<Vec<_>>()
};
for id in websocket_server_ids {
if let Some(resource) = self.websocket_servers.lock().await.remove(&id) {
resource.task.abort();
}
}
let sqlite_ids = {
let dbs = self.sqlite.lock().expect("sqlite lock");
dbs.iter()
.filter_map(|(id, resource)| {
(resource.owner_connection_id == connection_id).then(|| id.clone())
})
.collect::<Vec<_>>()
};
let mut dbs = self.sqlite.lock().expect("sqlite lock");
for id in sqlite_ids {
dbs.remove(&id);
}
}
}
impl RpcConnectionCleanupSink for NativeState {
    /// Disposes all native resources owned by the given connection.
    fn cleanup_connection<'a>(&'a self, connection_id: u64) -> ConnectionCleanupFuture<'a> {
        let disposal = self.dispose_connection_resources(connection_id);
        Box::pin(disposal)
    }
}
// Type-erased, pinned sink half of a WebSocket stream, so client and server-side
// connections can share one resource table regardless of the underlying transport type.
type BoxedWebSocketWriter =
std::pin::Pin<Box<dyn Sink<Message, Error = tokio_tungstenite::tungstenite::Error> + Send>>;
// Shareable, async-locked handle to a WebSocket sink.
type WebSocketWriter = Arc<AsyncMutex<BoxedWebSocketWriter>>;
// Table entry for an open file.
#[derive(Clone)]
struct OpenFileResource {
// Connection that owns the resource; used when cleaning up after a connection.
owner_connection_id: u64,
// Shared, async-locked file handle.
file: Arc<AsyncMutex<tokio::fs::File>>,
}
// Table entry for a TCP socket. Only the write half is stored; the read half is moved into
// the reader task spawned by `install_tcp_socket`.
#[derive(Clone)]
struct TcpSocketResource {
// Connection that owns the resource; used when cleaning up after a connection.
owner_connection_id: u64,
// Shared, async-locked write half of the socket.
writer: Arc<AsyncMutex<OwnedWriteHalf>>,
}
// Table entry for a WebSocket. Only the sink half is stored; the stream half is moved into
// the reader task spawned by `install_websocket`.
#[derive(Clone)]
struct WebSocketResource {
// Connection that owns the resource; used when cleaning up after a connection.
owner_connection_id: u64,
// Shared, async-locked sink used to send frames.
writer: WebSocketWriter,
}
// Table entry for a background task (the accept loop of a TCP or WebSocket server).
struct TaskResource {
// Connection that owns the resource; used when cleaning up after a connection.
owner_connection_id: u64,
// Handle used to abort the task when the server is closed.
task: JoinHandle<()>,
}
// Table entry for an open SQLite database.
struct SqliteResource {
// Connection that owns the resource; used when cleaning up after a connection.
owner_connection_id: u64,
// The open database connection; dropped when the entry is removed.
connection: rusqlite::Connection,
}
/// Declarative allow-list configuration for native operations.
///
/// Deserialized with camelCase keys; missing sections fall back to their defaults, which
/// allow nothing (empty lists, `false` flags).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct NativePolicyConfig {
/// Allowed filesystem roots for reading and writing.
pub filesystem: FileSystemPolicyConfig,
/// Allowed TCP and WebSocket endpoints.
pub network: NetworkPolicyConfig,
/// Allowed roots for SQLite database files.
pub sqlite: SqlitePolicyConfig,
/// Whether shutdown and reboot are permitted.
pub power: PowerPolicyConfig,
}
impl NativePolicyConfig {
pub fn dev_permissive() -> Self {
Self {
filesystem: FileSystemPolicyConfig {
read: vec![PathBuf::from("/")],
write: vec![PathBuf::from("/")],
},
network: NetworkPolicyConfig {
tcp: vec![NetworkRule::any()],
websocket: vec![NetworkRule::any()],
},
sqlite: SqlitePolicyConfig {
paths: vec![PathBuf::from("/")],
},
power: PowerPolicyConfig {
shutdown: true,
reboot: true,
},
}
}
pub fn from_json(value: &str) -> Result<Self, RuntimeError> {
serde_json::from_str(value).map_err(|error| {
RuntimeError::runtime(
RuntimeErrorCode::PayloadDecodeFailed,
format!("invalid native policy config JSON: {error}"),
)
})
}
pub fn allow_fs_read(mut self, path: impl Into<PathBuf>) -> Self {
self.filesystem.read.push(path.into());
self
}
pub fn allow_fs_write(mut self, path: impl Into<PathBuf>) -> Self {
self.filesystem.write.push(path.into());
self
}
pub fn allow_fs_read_write(mut self, path: impl Into<PathBuf>) -> Self {
let path = path.into();
self.filesystem.read.push(path.clone());
self.filesystem.write.push(path);
self
}
pub fn allow_tcp(mut self, host: impl Into<String>, port: Option<u16>) -> Self {
self.network.tcp.push(NetworkRule {
host: host.into(),
port,
});
self
}
pub fn allow_websocket(mut self, host: impl Into<String>, port: Option<u16>) -> Self {
self.network.websocket.push(NetworkRule {
host: host.into(),
port,
});
self
}
pub fn allow_sqlite(mut self, path: impl Into<PathBuf>) -> Self {
self.sqlite.paths.push(path.into());
self
}
pub fn allow_shutdown(mut self) -> Self {
self.power.shutdown = true;
self
}
pub fn allow_reboot(mut self) -> Self {
self.power.reboot = true;
self
}
}
/// Filesystem allow-lists; a path is permitted when it lies under one of the listed roots.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct FileSystemPolicyConfig {
/// Roots consulted for non-write operations.
pub read: Vec<PathBuf>,
/// Roots consulted for write operations.
pub write: Vec<PathBuf>,
}
/// Network allow-lists; an endpoint is permitted when any rule matches it.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct NetworkPolicyConfig {
/// Rules for operations whose name does not start with `websocket_`.
pub tcp: Vec<NetworkRule>,
/// Rules for operations whose name starts with `websocket_`.
pub websocket: Vec<NetworkRule>,
}
/// A single host/port allow rule.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct NetworkRule {
/// Host name compared ASCII-case-insensitively, or `"*"` to match any host.
pub host: String,
/// Allowed port; `None` matches every port.
pub port: Option<u16>,
}
impl Default for NetworkRule {
    /// An empty host with no port restriction. The empty host only matches an empty host
    /// string, so this is not a wildcard (see `NetworkRule::any` for that).
    fn default() -> Self {
        let host = String::default();
        Self { host, port: None }
    }
}
impl NetworkRule {
    /// Wildcard rule matching every host on every port.
    pub fn any() -> Self {
        Self {
            host: String::from("*"),
            port: None,
        }
    }

    /// Returns `true` when this rule permits `host:port`.
    ///
    /// The host matches when the rule's host is `"*"` or equals `host` ignoring ASCII case;
    /// the port matches when the rule has no port or the ports are equal.
    fn matches(&self, host: &str, port: u16) -> bool {
        let host_allowed = self.host == "*" || self.host.eq_ignore_ascii_case(host);
        let port_allowed = match self.port {
            None => true,
            Some(allowed) => allowed == port,
        };
        host_allowed && port_allowed
    }
}
/// SQLite allow-list; a database path is permitted when it lies under one of the roots.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct SqlitePolicyConfig {
/// Roots under which database files may be opened.
pub paths: Vec<PathBuf>,
}
/// Power-operation switches; both default to `false`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default, rename_all = "camelCase")]
pub struct PowerPolicyConfig {
/// Whether the `shutdown` operation is permitted.
pub shutdown: bool,
/// Whether the `reboot` operation is permitted.
pub reboot: bool,
}
/// Authorization hook consulted before privileged native operations.
///
/// Every check has a default implementation that allows the operation, so an implementor
/// only overrides the checks it wants to restrict.
pub trait NativePolicy: Send + Sync {
/// Short identifier of the policy, reported to clients by the runtime info call.
fn mode(&self) -> &'static str;
/// Checks a filesystem `_operation` on `_path`; allows everything by default.
fn allow_fs_path(&self, _operation: &str, _path: &std::path::Path) -> Result<(), RuntimeError> {
Ok(())
}
/// Checks a network `_operation` against `_host:_port`; allows everything by default.
fn allow_network(&self, _operation: &str, _host: &str, _port: u16) -> Result<(), RuntimeError> {
Ok(())
}
/// Checks opening the SQLite database at `_path`; allows everything by default.
fn allow_sqlite_path(&self, _path: &std::path::Path) -> Result<(), RuntimeError> {
Ok(())
}
/// Checks a power `_operation`; allows everything by default.
fn allow_power(&self, _operation: &str) -> Result<(), RuntimeError> {
Ok(())
}
}
/// Policy that enforces the allow-lists of a `NativePolicyConfig`.
///
/// The derived default wraps the default configuration, which allows nothing.
#[derive(Debug, Clone, Default)]
pub struct ConfiguredNativePolicy {
// Allow-lists consulted by every check.
config: NativePolicyConfig,
}
impl ConfiguredNativePolicy {
pub fn new(config: NativePolicyConfig) -> Self {
Self { config }
}
pub fn config(&self) -> &NativePolicyConfig {
&self.config
}
}
impl NativePolicy for ConfiguredNativePolicy {
    fn mode(&self) -> &'static str {
        "configured"
    }

    /// Allows the path when it lies under a configured root: the write roots for write
    /// operations, the read roots for everything else.
    fn allow_fs_path(&self, operation: &str, path: &Path) -> Result<(), RuntimeError> {
        let filesystem = &self.config.filesystem;
        let roots = if is_fs_write_operation(operation) {
            &filesystem.write
        } else {
            &filesystem.read
        };
        let permitted = roots.iter().any(|root| path_is_under(path, root));
        if permitted {
            Ok(())
        } else {
            Err(policy_denied("filesystem", operation))
        }
    }

    /// Allows the endpoint when any rule matches; operations prefixed with `websocket_`
    /// use the WebSocket rules, all others the TCP rules.
    fn allow_network(&self, operation: &str, host: &str, port: u16) -> Result<(), RuntimeError> {
        let network = &self.config.network;
        let rules = if operation.starts_with("websocket_") {
            &network.websocket
        } else {
            &network.tcp
        };
        let permitted = rules.iter().any(|rule| rule.matches(host, port));
        if permitted {
            Ok(())
        } else {
            Err(policy_denied("network", operation))
        }
    }

    /// Allows the database path when it lies under a configured SQLite root.
    fn allow_sqlite_path(&self, path: &Path) -> Result<(), RuntimeError> {
        let roots = &self.config.sqlite.paths;
        let permitted = roots.iter().any(|root| path_is_under(path, root));
        if permitted {
            Ok(())
        } else {
            Err(policy_denied("sqlite", "open"))
        }
    }

    /// Allows `shutdown` / `reboot` when the matching flag is set; unknown operations are
    /// always denied.
    fn allow_power(&self, operation: &str) -> Result<(), RuntimeError> {
        let power = &self.config.power;
        let permitted = (operation == "shutdown" && power.shutdown)
            || (operation == "reboot" && power.reboot);
        if permitted {
            Ok(())
        } else {
            Err(policy_denied("power", operation))
        }
    }
}
/// Development policy that relies on the `NativePolicy` defaults and therefore allows
/// every operation.
#[derive(Debug, Default)]
pub struct DevPermissivePolicy;
// Only `mode` is overridden; every permission check keeps the allow-all default.
impl NativePolicy for DevPermissivePolicy {
fn mode(&self) -> &'static str {
"dev-permissive"
}
}
pub fn build_native_rpc_server() -> RpcServer {
build_native_rpc_server_with_policy(Arc::new(DevPermissivePolicy))
}
/// Builds a native RPC server whose policy enforces the allow-lists in `config`.
pub fn build_native_rpc_server_with_config(config: NativePolicyConfig) -> RpcServer {
    let policy = ConfiguredNativePolicy::new(config);
    build_native_rpc_server_with_policy(Arc::new(policy))
}
/// Builds a native RPC server using `policy` and otherwise default options.
pub fn build_native_rpc_server_with_policy(policy: Arc<dyn NativePolicy>) -> RpcServer {
    let defaults = NativeRpcServerOptions::default();
    let options = NativeRpcServerOptions { policy, ..defaults };
    build_native_rpc_server_with_options(options)
}
pub fn build_native_rpc_server_with_options(options: NativeRpcServerOptions) -> RpcServer {
let provider_capabilities = options
.providers
.iter()
.flat_map(|provider| provider.capabilities())
.collect::<Vec<_>>();
let state = NativeState {
policy: options.policy,
services: options.services,
provider_capabilities: Arc::new(provider_capabilities),
..NativeState::default()
};
let mut builder = RpcServerBuilder::new();
builder.set_security(options.security);
builder.set_connection_cleanup_sink(Arc::new(state.clone()));
register(
&mut builder,
RUNTIME_INSTANCE,
RUNTIME_SERVICE_GUID,
1..=3,
NativeHandler::new(state.clone(), ServiceKind::Runtime),
);
if options.services.fs {
register(
&mut builder,
FS_INSTANCE,
FS_SERVICE_GUID,
1..=17,
NativeHandler::new(state.clone(), ServiceKind::Fs),
);
}
if options.services.archive {
register(
&mut builder,
ARCHIVE_INSTANCE,
ARCHIVE_SERVICE_GUID,
1..=2,
NativeHandler::new(state.clone(), ServiceKind::Archive),
);
}
if options.services.tcp {
register(
&mut builder,
TCP_INSTANCE,
TCP_SERVICE_GUID,
1..=6,
NativeHandler::new(state.clone(), ServiceKind::Tcp),
);
}
if options.services.websocket {
register(
&mut builder,
WEBSOCKET_INSTANCE,
WEBSOCKET_SERVICE_GUID,
1..=6,
NativeHandler::new(state.clone(), ServiceKind::WebSocket),
);
}
if options.services.sqlite {
register(
&mut builder,
SQLITE_INSTANCE,
SQLITE_SERVICE_GUID,
1..=7,
NativeHandler::new(state.clone(), ServiceKind::Sqlite),
);
}
if options.services.system {
register(
&mut builder,
SYSTEM_INSTANCE,
SYSTEM_SERVICE_GUID,
1..=3,
NativeHandler::new(state.clone(), ServiceKind::System),
);
}
for provider in &options.providers {
provider.register(&mut builder);
}
builder.build()
}
/// Registers `handler` on `builder` as the named instance `name` of the service `guid`,
/// exposing the given method ids, and returns the id assigned by the builder.
///
/// # Panics
///
/// Panics when `guid` is not a valid UUID string; callers are expected to pass a static,
/// well-formed service GUID.
pub fn register_native_named_instance(
    builder: &mut RpcServerBuilder,
    name: &str,
    guid: &str,
    methods: impl IntoIterator<Item = u32>,
    handler: Arc<dyn rpc_runtime_server::RpcServiceHandler>,
) -> InstanceId {
    let parsed = Uuid::parse_str(guid).expect("static service guid");
    let service = ServiceGuid::new(parsed);
    builder.register_named_instance(name, service, methods, handler)
}
fn register(
builder: &mut RpcServerBuilder,
name: &str,
guid: &str,
methods: impl IntoIterator<Item = u32>,
handler: NativeHandler,
) {
let _ = register_native_named_instance(builder, name, guid, methods, Arc::new(handler));
}
// Identifies which built-in service a `NativeHandler` dispatches to.
#[derive(Clone, Copy)]
enum ServiceKind {
Runtime,
Fs,
Archive,
Tcp,
WebSocket,
Sqlite,
System,
}
// RPC handler for one built-in service; all handlers share the same `NativeState`.
struct NativeHandler {
// Shared resource tables, policy and capability information.
state: NativeState,
// Service whose method table this handler dispatches to.
kind: ServiceKind,
}
impl NativeHandler {
fn new(state: NativeState, kind: ServiceKind) -> Self {
Self { state, kind }
}
}
impl rpc_runtime_server::RpcServiceHandler for NativeHandler {
    /// Dispatches the call to the method table of this handler's service.
    ///
    /// The archive and system services do not receive the call context; every other
    /// service does.
    fn call(&self, ctx: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture {
        let (state, kind) = (self.state.clone(), self.kind);
        Box::pin(async move {
            let method = method_id.get();
            match kind {
                ServiceKind::Runtime => runtime_call(state, ctx, method, payload).await,
                ServiceKind::Fs => filesystem::fs_call(state, ctx, method, payload).await,
                ServiceKind::Archive => archive::archive_call(state, method, payload).await,
                ServiceKind::Tcp => tcp_call(state, ctx, method, payload).await,
                ServiceKind::WebSocket => websocket_call(state, ctx, method, payload).await,
                ServiceKind::Sqlite => sqlite_call(state, ctx, method, payload).await,
                ServiceKind::System => system_call(state, method, payload).await,
            }
        })
    }
}
/// Handles the runtime service.
///
/// * method 1 — runtime info: `[os, arch, family, executable path or nil, capabilities,
///   policy mode]`
/// * method 2 — capability list
/// * method 3 — dispose every native resource owned by the calling connection
///
/// Any other method id yields a method-not-found error. The payload is ignored.
async fn runtime_call(
    state: NativeState,
    ctx: RpcCallContext,
    method: u32,
    _payload: Value,
) -> Result<Value, RuntimeError> {
    match method {
        1 => {
            let os = string(std::env::consts::OS);
            let arch = string(std::env::consts::ARCH);
            let family = string(std::env::consts::FAMILY);
            // Nil when the path is unavailable or not valid UTF-8.
            let executable = match std::env::current_exe() {
                Ok(path) => path.to_str().map_or(Value::Nil, string),
                Err(_) => Value::Nil,
            };
            let capabilities = string_list(native_capabilities(&state));
            let mode = string(state.policy.mode());
            Ok(array(vec![
                os,
                arch,
                family,
                executable,
                capabilities,
                mode,
            ]))
        }
        2 => {
            let capabilities = native_capabilities(&state);
            Ok(string_list(capabilities))
        }
        3 => {
            let connection_id = ctx.connection_id();
            state.dispose_connection_resources(connection_id).await;
            Ok(empty())
        }
        _ => Err(method_not_found(method)),
    }
}
/// Returns the built-in service capabilities followed by the provider capabilities.
fn native_capabilities(state: &NativeState) -> Vec<&'static str> {
    let built_in = state.services.capabilities();
    let from_providers = state.provider_capabilities.iter().copied();
    built_in.into_iter().chain(from_providers).collect()
}
/// Handles the TCP service.
///
/// * method 1 — connect to `[host, port]`; returns `[socket id]`
/// * method 2 — write `[socket id, bytes]`
/// * method 3 — shut down the write side of `[socket id]`
/// * method 4 — close `[socket id]` by dropping its table entry
/// * method 5 — listen on `[host, port]`; returns `[server id, local address]`
/// * method 6 — stop `[server id]` by aborting its accept loop
///
/// Connect and listen are checked against the policy first. Closing an unknown socket or
/// server is a no-op; any other method id yields a method-not-found error.
async fn tcp_call(
    state: NativeState,
    ctx: RpcCallContext,
    method: u32,
    payload: Value,
) -> Result<Value, RuntimeError> {
    match method {
        1 => {
            let host = string_arg(&payload, 0)?;
            let port = u16_arg(&payload, 1)?;
            state.policy.allow_network("tcp_connect", &host, port)?;
            let connected = TcpStream::connect((host.as_str(), port)).await;
            let stream = connected.map_err(io_error)?;
            let socket_id = resource_id("tcp");
            install_tcp_socket(state, ctx.clone(), socket_id.clone(), None, stream).await;
            Ok(array(vec![string(socket_id)]))
        }
        2 => {
            let socket_id = string_arg(&payload, 0)?;
            let bytes = bytes_arg(&payload, 1)?;
            let writer = socket_writer(&state, &socket_id).await?;
            let mut half = writer.lock().await;
            half.write_all(&bytes).await.map_err(io_error)?;
            Ok(empty())
        }
        3 => {
            let socket_id = string_arg(&payload, 0)?;
            let writer = socket_writer(&state, &socket_id).await?;
            let mut half = writer.lock().await;
            half.shutdown().await.map_err(io_error)?;
            Ok(empty())
        }
        4 => {
            let socket_id = string_arg(&payload, 0)?;
            let mut sockets = state.tcp_sockets.lock().await;
            sockets.remove(&socket_id);
            Ok(empty())
        }
        5 => {
            let host = string_arg(&payload, 0)?;
            let port = u16_arg(&payload, 1)?;
            state.policy.allow_network("tcp_listen", &host, port)?;
            let bound = TcpListener::bind((host.as_str(), port)).await;
            let listener = bound.map_err(io_error)?;
            let local_addr = listener.local_addr().map_err(io_error)?;
            let server_id = resource_id("tcp-server");
            let accept_loop = tcp_accept_loop(
                state.clone(),
                ctx.clone(),
                server_id.clone(),
                listener,
            );
            let task = tokio::spawn(accept_loop);
            let entry = TaskResource {
                owner_connection_id: ctx.connection_id(),
                task,
            };
            state
                .tcp_servers
                .lock()
                .await
                .insert(server_id.clone(), entry);
            Ok(array(vec![
                string(server_id),
                string(local_addr.to_string()),
            ]))
        }
        6 => {
            let server_id = string_arg(&payload, 0)?;
            let mut servers = state.tcp_servers.lock().await;
            if let Some(resource) = servers.remove(&server_id) {
                resource.task.abort();
            }
            Ok(empty())
        }
        _ => Err(method_not_found(method)),
    }
}
/// Registers `stream` under `id` and starts its background reader.
///
/// The write half is stored in the socket table, owned by the calling connection. The read
/// half moves into a spawned task that forwards every received chunk as a `"data"` event and
/// ends with a `"close"` event on EOF or an `"error"` event on a read failure; in both cases
/// the socket is then removed from the table. Notification failures are ignored.
async fn install_tcp_socket(
    state: NativeState,
    ctx: RpcCallContext,
    id: String,
    parent_id: Option<String>,
    stream: TcpStream,
) {
    let (mut reader, writer) = stream.into_split();
    let entry = TcpSocketResource {
        owner_connection_id: ctx.connection_id(),
        writer: Arc::new(AsyncMutex::new(writer)),
    };
    state.tcp_sockets.lock().await.insert(id.clone(), entry);

    tokio::spawn(async move {
        let mut chunk = vec![0_u8; 8192];
        // `None` means the peer closed the stream, `Some` carries the read error text.
        let failure = loop {
            match reader.read(&mut chunk).await {
                Ok(0) => break None,
                Ok(len) => {
                    let _ = notify_tcp(
                        &ctx,
                        "data",
                        &id,
                        parent_id.as_deref(),
                        Some(&chunk[..len]),
                        None,
                    )
                    .await;
                }
                Err(error) => break Some(error.to_string()),
            }
        };
        let kind = if failure.is_some() { "error" } else { "close" };
        let _ = notify_tcp(&ctx, kind, &id, parent_id.as_deref(), None, failure).await;
        state.tcp_sockets.lock().await.remove(&id);
    });
}
async fn tcp_accept_loop(
state: NativeState,
ctx: RpcCallContext,
server_id: String,
listener: TcpListener,
) {
loop {
match listener.accept().await {
Ok((stream, _)) => {
let socket_id = resource_id("tcp");
install_tcp_socket(
state.clone(),
ctx.clone(),
socket_id.clone(),
Some(server_id.clone()),
stream,
)
.await;
let _ =
notify_tcp(&ctx, "connection", &socket_id, Some(&server_id), None, None).await;
}
Err(error) => {
let _ = notify_tcp(
&ctx,
"error",
&server_id,
None,
None,
Some(error.to_string()),
)
.await;
break;
}
}
}
}
/// Looks up the shared write half of the TCP socket `id`.
///
/// # Errors
///
/// Returns a runtime error when no socket with that id is registered.
async fn socket_writer(
    state: &NativeState,
    id: &str,
) -> Result<Arc<AsyncMutex<OwnedWriteHalf>>, RuntimeError> {
    let sockets = state.tcp_sockets.lock().await;
    match sockets.get(id) {
        Some(resource) => Ok(Arc::clone(&resource.writer)),
        None => Err(runtime_error(format!("TCP socket `{id}` was not found"))),
    }
}
/// Sends a TCP event notification to the bound client.
///
/// The payload is `[kind, id, parent id, data, message]`; absent optional values are
/// encoded as nil.
async fn notify_tcp(
    ctx: &RpcCallContext,
    kind: &str,
    id: &str,
    parent_id: Option<&str>,
    data: Option<&[u8]>,
    message: Option<String>,
) -> Result<(), RuntimeError> {
    let fields = vec![
        string(kind),
        string(id),
        parent_id.map_or(Value::Nil, string),
        data.map_or(Value::Nil, |bytes| Value::Binary(bytes.to_vec())),
        message.map_or(Value::Nil, string),
    ];
    ctx.notify_bound(EVENT_NOTIFICATION_ID, Value::Array(fields))
        .await
}
/// Handles the WebSocket service.
///
/// * method 1 — connect to `[url]`; returns `[socket id]`
/// * method 2 — send text `[socket id, text]`
/// * method 3 — send binary `[socket id, bytes]`
/// * method 4 — close `[socket id]`: removes it from the table and sends a best-effort
///   Close frame
/// * method 5 — listen on `[host, port]`; returns `[server id, "ws://<local address>"]`
/// * method 6 — stop `[server id]` by aborting its accept loop
///
/// Connect and listen are checked against the policy first. Closing an unknown socket or
/// server is a no-op; any other method id yields a method-not-found error.
async fn websocket_call(
    state: NativeState,
    ctx: RpcCallContext,
    method: u32,
    payload: Value,
) -> Result<Value, RuntimeError> {
    match method {
        1 => {
            let url = string_arg(&payload, 0)?;
            let (host, port) = websocket_url_authority(&url)?;
            state
                .policy
                .allow_network("websocket_connect", &host, port)?;
            let (stream, _) = tokio_tungstenite::connect_async(&url)
                .await
                .map_err(ws_error)?;
            let id = resource_id("ws");
            install_websocket(state, ctx, id.clone(), None, stream).await;
            Ok(array(vec![string(id)]))
        }
        2 => {
            let id = string_arg(&payload, 0)?;
            let text = string_arg(&payload, 1)?;
            websocket_writer(&state, &id)
                .await?
                .lock()
                .await
                .send(Message::Text(text.into()))
                .await
                .map_err(ws_error)?;
            Ok(empty())
        }
        3 => {
            let id = string_arg(&payload, 0)?;
            let bytes = bytes_arg(&payload, 1)?;
            websocket_writer(&state, &id)
                .await?
                .lock()
                .await
                .send(Message::Binary(bytes.into()))
                .await
                .map_err(ws_error)?;
            Ok(empty())
        }
        4 => {
            let id = string_arg(&payload, 0)?;
            // Take the entry out in its own statement so the table lock is released before
            // the Close frame is sent. Keeping the lock expression in the `if let` scrutinee
            // would hold the table locked for the whole network send.
            let removed = state.websocket_sockets.lock().await.remove(&id);
            if let Some(resource) = removed {
                // Best effort: the peer may already be gone.
                let _ = resource
                    .writer
                    .lock()
                    .await
                    .send(Message::Close(None))
                    .await;
            }
            Ok(empty())
        }
        5 => {
            let host = string_arg(&payload, 0)?;
            let port = u16_arg(&payload, 1)?;
            state
                .policy
                .allow_network("websocket_listen", &host, port)?;
            let listener = TcpListener::bind((host.as_str(), port))
                .await
                .map_err(io_error)?;
            let addr = listener.local_addr().map_err(io_error)?;
            let id = resource_id("ws-server");
            let server_id = id.clone();
            let state_for_task = state.clone();
            let ctx_for_task = ctx.clone();
            let task = tokio::spawn(async move {
                websocket_accept_loop(state_for_task, ctx_for_task, server_id, listener).await;
            });
            state.websocket_servers.lock().await.insert(
                id.clone(),
                TaskResource {
                    owner_connection_id: ctx.connection_id(),
                    task,
                },
            );
            Ok(array(vec![string(id), string(format!("ws://{addr}"))]))
        }
        6 => {
            let id = string_arg(&payload, 0)?;
            // Aborting is synchronous, so holding the table lock here is harmless.
            if let Some(resource) = state.websocket_servers.lock().await.remove(&id) {
                resource.task.abort();
            }
            Ok(empty())
        }
        _ => Err(method_not_found(method)),
    }
}
/// Registers the WebSocket `stream` under `id` and starts its background reader.
///
/// The sink half is stored in the socket table, owned by the calling connection. The stream
/// half moves into a spawned task that forwards text and binary frames as `"text"` /
/// `"binary"` events, ignores other frame types, and stops after a Close frame (`"close"`
/// event), a read error (`"error"` event) or the end of the stream. The socket is removed
/// from the table once the task stops. Notification failures are ignored.
async fn install_websocket<S>(
    state: NativeState,
    ctx: RpcCallContext,
    id: String,
    parent_id: Option<String>,
    stream: tokio_tungstenite::WebSocketStream<S>,
) where
    tokio_tungstenite::WebSocketStream<S>: futures_util::Stream<Item = Result<Message, tokio_tungstenite::tungstenite::Error>>
        + futures_util::Sink<Message, Error = tokio_tungstenite::tungstenite::Error>
        + Unpin
        + Send
        + 'static,
{
    let (writer, mut reader) = stream.split();
    let entry = WebSocketResource {
        owner_connection_id: ctx.connection_id(),
        writer: Arc::new(AsyncMutex::new(Box::pin(writer))),
    };
    state.websocket_sockets.lock().await.insert(id.clone(), entry);

    tokio::spawn(async move {
        while let Some(message) = reader.next().await {
            let parent = parent_id.as_deref();
            // Each arm reports its event and says whether the reader should stop.
            let finished = match message {
                Ok(Message::Text(text)) => {
                    let _ = notify_ws(&ctx, "text", &id, parent, None, Some(&text), None).await;
                    false
                }
                Ok(Message::Binary(bytes)) => {
                    let _ = notify_ws(&ctx, "binary", &id, parent, Some(&bytes), None, None).await;
                    false
                }
                Ok(Message::Close(_)) => {
                    let _ = notify_ws(&ctx, "close", &id, parent, None, None, None).await;
                    true
                }
                Ok(_) => false,
                Err(error) => {
                    let _ = notify_ws(
                        &ctx,
                        "error",
                        &id,
                        parent,
                        None,
                        None,
                        Some(error.to_string()),
                    )
                    .await;
                    true
                }
            };
            if finished {
                break;
            }
        }
        state.websocket_sockets.lock().await.remove(&id);
    });
}
/// Accepts WebSocket connections on `listener` until a TCP accept call fails.
///
/// Each accepted TCP stream goes through the WebSocket handshake inline. A successful
/// handshake installs the socket with `server_id` as parent and emits a `"connection"`
/// event; a failed handshake emits an `"error"` event for the server and the loop keeps
/// accepting. A failed TCP accept emits an `"error"` event and ends the loop.
async fn websocket_accept_loop(
    state: NativeState,
    ctx: RpcCallContext,
    server_id: String,
    listener: TcpListener,
) {
    loop {
        let stream = match listener.accept().await {
            Ok((stream, _peer)) => stream,
            Err(error) => {
                let _ = notify_ws(
                    &ctx,
                    "error",
                    &server_id,
                    None,
                    None,
                    None,
                    Some(error.to_string()),
                )
                .await;
                return;
            }
        };
        let ws = match tokio_tungstenite::accept_async(stream).await {
            Ok(ws) => ws,
            Err(error) => {
                let _ = notify_ws(
                    &ctx,
                    "error",
                    &server_id,
                    None,
                    None,
                    None,
                    Some(error.to_string()),
                )
                .await;
                continue;
            }
        };
        let socket_id = resource_id("ws");
        install_websocket(
            state.clone(),
            ctx.clone(),
            socket_id.clone(),
            Some(server_id.clone()),
            ws,
        )
        .await;
        let _ = notify_ws(
            &ctx,
            "connection",
            &socket_id,
            Some(&server_id),
            None,
            None,
            None,
        )
        .await;
    }
}
/// Looks up the shared sink of the WebSocket `id`.
///
/// # Errors
///
/// Returns a runtime error when no WebSocket with that id is registered.
async fn websocket_writer(state: &NativeState, id: &str) -> Result<WebSocketWriter, RuntimeError> {
    let sockets = state.websocket_sockets.lock().await;
    match sockets.get(id) {
        Some(resource) => Ok(resource.writer.clone()),
        None => Err(runtime_error(format!("WebSocket `{id}` was not found"))),
    }
}
/// Sends a WebSocket event notification to the bound client.
///
/// The payload is `[kind, id, parent id, data, text, message]`; absent optional values are
/// encoded as nil.
async fn notify_ws(
    ctx: &RpcCallContext,
    kind: &str,
    id: &str,
    parent_id: Option<&str>,
    data: Option<&[u8]>,
    text: Option<&str>,
    message: Option<String>,
) -> Result<(), RuntimeError> {
    let fields = vec![
        string(kind),
        string(id),
        parent_id.map_or(Value::Nil, string),
        data.map_or(Value::Nil, |bytes| Value::Binary(bytes.to_vec())),
        text.map_or(Value::Nil, string),
        message.map_or(Value::Nil, string),
    ];
    ctx.notify_bound(EVENT_NOTIFICATION_ID, Value::Array(fields))
        .await
}
async fn sqlite_call(
state: NativeState,
ctx: RpcCallContext,
method: u32,
payload: Value,
) -> Result<Value, RuntimeError> {
match method {
1 => {
let path = string_arg(&payload, 0)?;
let path = PathBuf::from(path);
state.policy.allow_sqlite_path(&path)?;
let db = rusqlite::Connection::open(path).map_err(sqlite_error)?;
let id = resource_id("sqlite");
state.sqlite.lock().expect("sqlite lock").insert(
id.clone(),
SqliteResource {
owner_connection_id: ctx.connection_id(),
connection: db,
},
);
Ok(array(vec![string(id)]))
}
2 => {
let id = string_arg(&payload, 0)?;
state.sqlite.lock().expect("sqlite lock").remove(&id);
Ok(empty())
}
3 => with_db(&state, &payload, |db, payload| {
db.execute_batch(&string_arg(payload, 1)?)
.map_err(sqlite_error)?;
Ok(empty())
}),
4 => with_db(&state, &payload, |db, payload| {
let changes = execute_with_params(db, &string_arg(payload, 1)?, field(payload, 2))?;
Ok(array(vec![
Value::from(changes as u64),
Value::from(db.last_insert_rowid()),
]))
}),
5 => with_db(&state, &payload, |db, payload| {
let rows = query_rows(db, &string_arg(payload, 1)?, field(payload, 2))?;
Ok(rows.into_iter().next().unwrap_or(Value::Nil))
}),
6 => with_db(&state, &payload, |db, payload| {
Ok(Value::Array(query_rows(
db,
&string_arg(payload, 1)?,
field(payload, 2),
)?))
}),
7 => with_db(&state, &payload, |db, payload| {
db.execute_batch("BEGIN IMMEDIATE").map_err(sqlite_error)?;
let statements = array_arg(payload, 1)?;
let result = (|| {
for statement in statements {
db.execute_batch(&value_string(statement)?)
.map_err(sqlite_error)?;
}
Ok::<_, RuntimeError>(())
})();
if result.is_ok() {
db.execute_batch("COMMIT").map_err(sqlite_error)?;
} else {
let _ = db.execute_batch("ROLLBACK");
}
result.map(|_| empty())
}),
_ => Err(method_not_found(method)),
}
}
/// Runs `op` against the database whose id is field 0 of `payload`.
///
/// The database table stays locked while `op` runs.
///
/// # Errors
///
/// Returns a decode error when field 0 is not a string, a runtime error when the id is
/// unknown, and otherwise whatever `op` returns.
///
/// # Panics
///
/// Panics when the database table mutex is poisoned.
fn with_db(
    state: &NativeState,
    payload: &Value,
    op: impl FnOnce(&rusqlite::Connection, &Value) -> Result<Value, RuntimeError>,
) -> Result<Value, RuntimeError> {
    let id = string_arg(payload, 0)?;
    let databases = state.sqlite.lock().expect("sqlite lock");
    let Some(resource) = databases.get(&id) else {
        return Err(runtime_error(format!(
            "SQLite database `{id}` was not found"
        )));
    };
    op(&resource.connection, payload)
}
/// Executes `sql` with the decoded `params` and returns the number of changed rows.
fn execute_with_params(
    db: &rusqlite::Connection,
    sql: &str,
    params: Option<&Value>,
) -> Result<usize, RuntimeError> {
    let bound = sqlite_params(params)?;
    let changed = db.execute(sql, rusqlite::params_from_iter(bound));
    changed.map_err(sqlite_error)
}
/// Runs the query `sql` with the decoded `params` and returns every row.
///
/// Each row is encoded as `[[[column name, tagged value], ...]]`: a one-element array
/// holding the list of `[name, value]` pairs in column order.
fn query_rows(
    db: &rusqlite::Connection,
    sql: &str,
    params: Option<&Value>,
) -> Result<Vec<Value>, RuntimeError> {
    let bound = sqlite_params(params)?;
    let mut statement = db.prepare(sql).map_err(sqlite_error)?;
    // Copy the names first: running the query borrows the statement mutably.
    let names: Vec<String> = statement
        .column_names()
        .iter()
        .map(|name| name.to_string())
        .collect();
    let rows = statement
        .query_map(rusqlite::params_from_iter(bound), |row| {
            let columns = names
                .iter()
                .enumerate()
                .map(|(index, name)| -> rusqlite::Result<Value> {
                    let cell = sqlite_value(row.get_ref(index)?);
                    Ok(array(vec![string(name), cell]))
                })
                .collect::<rusqlite::Result<Vec<Value>>>()?;
            Ok(array(vec![Value::Array(columns)]))
        })
        .map_err(sqlite_error)?;
    let mut collected = Vec::new();
    for row in rows {
        collected.push(row.map_err(sqlite_error)?);
    }
    Ok(collected)
}
/// Decodes an optional parameter list; anything other than an array yields no parameters.
fn sqlite_params(input: Option<&Value>) -> Result<Vec<rusqlite::types::Value>, RuntimeError> {
    match input {
        Some(Value::Array(values)) => values.iter().map(sqlite_param).collect(),
        _ => Ok(Vec::new()),
    }
}
fn sqlite_param(value: &Value) -> Result<rusqlite::types::Value, RuntimeError> {
if let Value::Array(fields) = value {
let kind = string_arg(value, 0)?;
return match kind.as_str() {
"null" => Ok(rusqlite::types::Value::Null),
"integer" => Ok(rusqlite::types::Value::Integer(i64_arg_from_fields(
fields, 1,
)?)),
"real" => Ok(rusqlite::types::Value::Real(f64_arg_from_fields(
fields, 2,
)?)),
"text" => Ok(rusqlite::types::Value::Text(string_arg(value, 3)?)),
"blob" => Ok(rusqlite::types::Value::Blob(bytes_arg(value, 4)?)),
"boolean" => Ok(rusqlite::types::Value::Integer(i64::from(
bool_arg(value, 5).unwrap_or(false),
))),
_ => Err(decode_error(format!(
"unsupported SQLite tagged parameter kind `{kind}`"
))),
};
}
match value {
Value::Nil => Ok(rusqlite::types::Value::Null),
Value::Boolean(value) => Ok(rusqlite::types::Value::Integer(i64::from(*value))),
Value::Integer(value) => Ok(rusqlite::types::Value::Integer(
value
.as_i64()
.ok_or_else(|| decode_error("integer out of range"))?,
)),
Value::F32(value) => Ok(rusqlite::types::Value::Real((*value).into())),
Value::F64(value) => Ok(rusqlite::types::Value::Real(*value)),
Value::String(value) => Ok(rusqlite::types::Value::Text(
value.as_str().unwrap_or_default().to_string(),
)),
Value::Binary(value) => Ok(rusqlite::types::Value::Blob(value.clone())),
_ => Err(decode_error("unsupported SQLite parameter")),
}
}
/// Encodes a SQLite cell as the tagged array `[kind, integer, real, text, blob, boolean]`.
///
/// Only the slot matching `kind` is filled; every other slot is nil. The boolean slot is
/// never filled because SQLite has no boolean storage class. Text that is not valid UTF-8
/// is converted lossily.
fn sqlite_value(value: rusqlite::types::ValueRef<'_>) -> Value {
    use rusqlite::types::ValueRef;

    // Tag plus the (slot index, payload) to fill in, if any.
    let (tag, filled) = match value {
        ValueRef::Null => ("null", None),
        ValueRef::Integer(number) => ("integer", Some((1, Value::from(number)))),
        ValueRef::Real(number) => ("real", Some((2, Value::from(number)))),
        ValueRef::Text(text) => ("text", Some((3, string(String::from_utf8_lossy(text))))),
        ValueRef::Blob(bytes) => ("blob", Some((4, Value::Binary(bytes.to_vec())))),
    };
    let mut slots = vec![Value::Nil; 6];
    slots[0] = string(tag);
    if let Some((index, payload)) = filled {
        slots[index] = payload;
    }
    Value::Array(slots)
}
/// Handles the system service.
///
/// * method 1 — list the supported power operations
/// * method 2 — shut down (policy-checked)
/// * method 3 — reboot (policy-checked)
///
/// Any other method id yields a method-not-found error.
async fn system_call(
    state: NativeState,
    method: u32,
    payload: Value,
) -> Result<Value, RuntimeError> {
    let operation = match method {
        1 => return Ok(string_list(vec!["shutdown", "reboot"])),
        2 => "shutdown",
        3 => "reboot",
        _ => return Err(method_not_found(method)),
    };
    state.policy.allow_power(operation)?;
    run_power_command(operation, payload).await
}
/// Spawns the platform command that shuts down or reboots the machine.
///
/// `kind` selects a reboot when it equals `"reboot"`; any other value is
/// treated as a shutdown. Field 0 of `payload` is read as the delay in
/// seconds; a missing or invalid field falls back to `0`.
///
/// * Windows: `shutdown /r|/s /t <delay>`.
/// * macOS: `osascript -e 'tell app "System Events" to restart|shut down'`.
/// * Everything else: `systemctl reboot|poweroff`.
///
/// The child process is spawned with all standard streams detached and is
/// not awaited, so only a spawn failure is reported.
///
/// # Errors
///
/// Returns a runtime error when a non-zero delay is requested on a platform
/// whose command cannot honour it (macOS and the `systemctl` fallback), or
/// when the process cannot be spawned.
async fn run_power_command(kind: &str, payload: Value) -> Result<Value, RuntimeError> {
    let delay = u64_arg(&payload, 0).unwrap_or(0);
    let reboot = kind == "reboot";
    let mut command = if cfg!(target_os = "windows") {
        let mut command = tokio::process::Command::new("shutdown");
        command.arg(if reboot { "/r" } else { "/s" });
        command.arg("/t").arg(delay.to_string());
        command
    } else if cfg!(target_os = "macos") {
        // The AppleScript below acts immediately; refuse a delay rather than
        // silently powering off earlier than the caller asked for.
        if delay > 0 {
            return Err(runtime_error(
                "delayed shutdown/reboot is not supported by the macOS provider",
            ));
        }
        let mut command = tokio::process::Command::new("osascript");
        command.arg("-e").arg(if reboot {
            "tell app \"System Events\" to restart"
        } else {
            "tell app \"System Events\" to shut down"
        });
        command
    } else {
        // Validate before building the command so nothing is set up in vain.
        if delay > 0 {
            return Err(runtime_error(
                "delayed shutdown/reboot is not supported by the Linux provider",
            ));
        }
        let mut command = tokio::process::Command::new("systemctl");
        command.arg(if reboot { "reboot" } else { "poweroff" });
        command
    };
    command
        .stdin(Stdio::null())
        .stdout(Stdio::null())
        .stderr(Stdio::null());
    command.spawn().map_err(io_error)?;
    Ok(empty())
}
/// Returns an empty array value, used as the result of calls that have
/// nothing to report.
fn empty() -> Value {
    let no_items: Vec<Value> = vec![];
    Value::Array(no_items)
}
/// Wraps a vector of values into a single array value.
fn array(values: Vec<Value>) -> Value {
    let items = values;
    Value::Array(items)
}
/// Builds a string value from anything that can be viewed as `&str`.
/// The text is copied into a newly allocated `String`.
fn string(value: impl AsRef<str>) -> Value {
    let text: String = value.as_ref().to_owned();
    Value::String(text.into())
}
fn string_list(values: Vec<impl AsRef<str>>) -> Value {
Value::Array(values.into_iter().map(string).collect())
}
/// Returns the element at `index` when `value` is an array.
///
/// Yields `None` when `value` is not an array or the index is out of range.
fn field(value: &Value, index: usize) -> Option<&Value> {
    let Value::Array(items) = value else {
        return None;
    };
    items.get(index)
}
/// Reads field `index` of `value` as a nested array and returns its items.
///
/// # Errors
///
/// Returns a decode error when the field is missing or is not an array.
fn array_arg(value: &Value, index: usize) -> Result<&[Value], RuntimeError> {
    if let Some(Value::Array(items)) = field(value, index) {
        return Ok(items.as_slice());
    }
    Err(decode_error(format!("field {index} must be an array")))
}
fn string_arg(value: &Value, index: usize) -> Result<String, RuntimeError> {
field(value, index)
.map(value_string)
.transpose()?
.ok_or_else(|| decode_error(format!("field {index} must be a string")))
}
/// Converts a string value into an owned `String`.
///
/// # Errors
///
/// Returns a decode error when `value` is not a string, or when it is a
/// string whose bytes are not valid UTF-8. The latter used to be silently
/// replaced by an empty string, which hid corrupted input from callers.
fn value_string(value: &Value) -> Result<String, RuntimeError> {
    match value {
        Value::String(text) => text
            .as_str()
            .map(str::to_owned)
            .ok_or_else(|| decode_error("string is not valid UTF-8")),
        _ => Err(decode_error("expected string")),
    }
}
/// Reads field `index` of `value` as a string and converts it to a path.
///
/// # Errors
///
/// Propagates the decode error from `string_arg`.
fn path_arg(value: &Value, index: usize) -> Result<PathBuf, RuntimeError> {
    string_arg(value, index).map(PathBuf::from)
}
/// Reads field `index` of `value` as raw bytes.
///
/// Binary values are copied as-is. String values contribute their
/// underlying bytes, including when those bytes are not valid UTF-8
/// (previously such strings silently produced an empty buffer).
///
/// # Errors
///
/// Returns a decode error when the field is missing or is neither binary
/// nor a string.
fn bytes_arg(value: &Value, index: usize) -> Result<Vec<u8>, RuntimeError> {
    match field(value, index) {
        Some(Value::Binary(bytes)) => Ok(bytes.clone()),
        // `as_bytes` exposes the raw payload regardless of UTF-8 validity.
        Some(Value::String(text)) => Ok(text.as_bytes().to_vec()),
        _ => Err(decode_error(format!("field {index} must be bytes"))),
    }
}
/// Reads field `index` of `value` as a boolean.
///
/// Yields `None` when the field is missing or is not a boolean.
fn bool_arg(value: &Value, index: usize) -> Option<bool> {
    if let Some(Value::Boolean(flag)) = field(value, index) {
        Some(*flag)
    } else {
        None
    }
}
/// Reads element `index` of an already-unpacked field slice as a boolean.
///
/// Yields `None` when the index is out of range or the element is not a
/// boolean.
fn bool_field(fields: &[Value], index: usize) -> Option<bool> {
    if let Some(Value::Boolean(flag)) = fields.get(index) {
        Some(*flag)
    } else {
        None
    }
}
/// Reads field `index` of `value` as a `u16`.
///
/// # Errors
///
/// Propagates the decode error from `u64_arg`, or returns a decode error
/// when the number does not fit in 16 bits.
fn u16_arg(value: &Value, index: usize) -> Result<u16, RuntimeError> {
    let wide = u64_arg(value, index)?;
    match u16::try_from(wide) {
        Ok(narrow) => Ok(narrow),
        Err(_) => Err(decode_error(format!("field {index} is out of u16 range"))),
    }
}
/// Reads field `index` of `value` as an `i64`.
///
/// # Errors
///
/// Returns a decode error when the field is missing, is not an integer, or
/// does not fit in an `i64`.
fn i64_arg(value: &Value, index: usize) -> Result<i64, RuntimeError> {
    let parsed = match field(value, index) {
        Some(Value::Integer(number)) => number.as_i64(),
        _ => None,
    };
    parsed.ok_or_else(|| decode_error(format!("field {index} must be an integer")))
}
/// Reads field `index` of `value` as a `u64`.
///
/// # Errors
///
/// Returns a decode error when the field is missing or not an integer, and a
/// distinct decode error when the integer cannot be represented as a `u64`.
fn u64_arg(value: &Value, index: usize) -> Result<u64, RuntimeError> {
    let Some(Value::Integer(number)) = field(value, index) else {
        return Err(decode_error(format!("field {index} must be an integer")));
    };
    number
        .as_u64()
        .ok_or_else(|| decode_error(format!("field {index} must be a non-negative integer")))
}
/// Reads element `index` of an already-unpacked field slice as an `i64`.
///
/// # Errors
///
/// Returns a decode error when the element is missing, is not an integer,
/// or does not fit in an `i64`.
fn i64_arg_from_fields(fields: &[Value], index: usize) -> Result<i64, RuntimeError> {
    fields
        .get(index)
        .and_then(|entry| match entry {
            Value::Integer(number) => number.as_i64(),
            _ => None,
        })
        .ok_or_else(|| decode_error(format!("field {index} must be an integer")))
}
/// Reads element `index` of an already-unpacked field slice as an `f64`.
///
/// Accepts 64-bit floats, 32-bit floats (widened) and integers (converted
/// through `as_f64`).
///
/// # Errors
///
/// Returns a decode error when the element is missing, is not numeric, or
/// the integer conversion yields no value.
fn f64_arg_from_fields(fields: &[Value], index: usize) -> Result<f64, RuntimeError> {
    let number = match fields.get(index) {
        Some(Value::F64(double)) => Some(*double),
        Some(Value::F32(single)) => Some(f64::from(*single)),
        Some(Value::Integer(integer)) => integer.as_f64(),
        _ => None,
    };
    number.ok_or_else(|| decode_error(format!("field {index} must be a number")))
}
/// Generates a resource identifier of the form `<prefix>-<uuid>`, using a
/// freshly generated version-4 UUID.
fn resource_id(prefix: &str) -> String {
    let unique = Uuid::new_v4();
    format!("{prefix}-{unique}")
}
/// Looks up an open file resource by id and returns a shared handle to it.
///
/// The `open_files` table is locked only for the duration of the lookup;
/// the returned `Arc` can be used after the table lock is released.
///
/// # Errors
///
/// Returns a runtime error when no resource with that id is registered.
async fn open_file(
    state: &NativeState,
    id: &str,
) -> Result<Arc<AsyncMutex<tokio::fs::File>>, RuntimeError> {
    let registry = state.open_files.lock().await;
    match registry.get(id) {
        Some(resource) => Ok(resource.file.clone()),
        None => Err(runtime_error(format!(
            "file resource `{id}` was not found"
        ))),
    }
}
/// Reports whether the named filesystem/archive operation is one of the
/// operations classified here as a write.
///
/// The comparison is exact and case-sensitive; unknown names return `false`.
fn is_fs_write_operation(operation: &str) -> bool {
    const WRITE_OPERATIONS: [&str; 9] = [
        "write_file",
        "append_file",
        "mkdir",
        "remove",
        "rename_to",
        "copy_to",
        "open_file_write",
        "archive_zip_write",
        "archive_unzip_write",
    ];
    WRITE_OPERATIONS.contains(&operation)
}
/// Reports whether `path` is `root` itself or lies beneath it, after both
/// have been normalized by `policy_path`.
///
/// Any failure to normalize either side is treated as "not under".
fn path_is_under(path: &Path, root: &Path) -> bool {
    match (policy_path(path), policy_path(root)) {
        // `starts_with` compares whole components and is also true for
        // equal paths, so no separate equality test is needed.
        (Ok(candidate), Ok(base)) => candidate.starts_with(base),
        _ => false,
    }
}
/// Normalizes `path` into an absolute form suitable for prefix comparisons,
/// even when the path (or some of its ancestors) does not exist yet.
///
/// * An existing path is simply canonicalized.
/// * Otherwise the path is made absolute against the current directory and
///   walked component by component: `.` is dropped, `..` removes the last
///   resolved component, and after each root or named component the prefix
///   resolved so far is canonicalized whenever that succeeds.
///
/// Resolving `..` here matters because `Path::starts_with` compares
/// components literally: without it `root/missing/../../elsewhere` would
/// still appear to live under `root`.
///
/// NOTE(review): a final component that is a dangling symlink is kept as
/// written, because it cannot be canonicalized — confirm whether callers
/// need to reject that case separately.
///
/// # Errors
///
/// Returns the I/O error from canonicalizing an existing path or from
/// reading the current directory.
fn policy_path(path: &Path) -> Result<PathBuf, std::io::Error> {
    if path.exists() {
        return path.canonicalize();
    }
    let absolute = if path.is_absolute() {
        path.to_path_buf()
    } else {
        std::env::current_dir()?.join(path)
    };
    let mut resolved = PathBuf::new();
    for component in absolute.components() {
        match component {
            std::path::Component::CurDir => {}
            std::path::Component::ParentDir => {
                // `pop` is a no-op at the root, matching how `/..` behaves.
                resolved.pop();
            }
            // A bare prefix (e.g. a drive letter) is not canonicalized on its
            // own; the root component that follows takes care of it.
            std::path::Component::Prefix(_) => resolved.push(component.as_os_str()),
            std::path::Component::RootDir | std::path::Component::Normal(_) => {
                resolved.push(component.as_os_str());
                // Resolve symlinks for as long as the prefix exists on disk;
                // once it does not, the remainder stays purely lexical.
                if let Ok(canonical) = resolved.canonicalize() {
                    resolved = canonical;
                }
            }
        }
    }
    Ok(resolved)
}
/// Extracts the host and port from a `ws://` or `wss://` URL.
///
/// The port defaults to 80 for `ws` and 443 for `wss` when the URL does not
/// specify one. IPv6 literals must be bracketed and are returned without
/// the brackets. Any `userinfo@` prefix in the authority is discarded so
/// that it never becomes part of the returned host.
///
/// # Errors
///
/// Returns a decode error when the scheme is not `ws`/`wss`, the host is
/// missing, a bracketed IPv6 host is unterminated or followed by anything
/// other than `:port`, or the port is not a valid `u16`.
fn websocket_url_authority(url: &str) -> Result<(String, u16), RuntimeError> {
    let (default_port, rest) = if let Some(rest) = url.strip_prefix("ws://") {
        (80, rest)
    } else if let Some(rest) = url.strip_prefix("wss://") {
        (443, rest)
    } else {
        return Err(decode_error(
            "websocket URL must start with ws:// or wss://",
        ));
    };
    // The authority ends at the first path, query or fragment delimiter.
    let authority = rest
        .split(['/', '?', '#'])
        .next()
        .filter(|value| !value.is_empty())
        .ok_or_else(|| decode_error("websocket URL must include a host"))?;
    // Drop `userinfo@`; only what follows the last `@` names the endpoint.
    let authority = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host_port)| host_port);
    let (host, port) = if authority.starts_with('[') {
        let end = authority
            .find(']')
            .ok_or_else(|| decode_error("websocket URL has invalid IPv6 host"))?;
        let host = authority[1..end].to_string();
        let trailer = &authority[end + 1..];
        let port = match trailer.strip_prefix(':') {
            Some(port) => parse_port(port)?,
            None if trailer.is_empty() => default_port,
            // Anything else after `]` is junk, not a port.
            None => {
                return Err(decode_error(
                    "websocket URL has unexpected characters after IPv6 host",
                ))
            }
        };
        (host, port)
    } else {
        match authority.rsplit_once(':') {
            Some((host, port)) => (host.to_string(), parse_port(port)?),
            None => (authority.to_string(), default_port),
        }
    };
    if host.is_empty() {
        return Err(decode_error("websocket URL must include a host"));
    }
    Ok((host, port))
}
/// Parses the port part of a websocket URL authority.
///
/// Only plain ASCII digits are accepted. `str::parse::<u16>` on its own
/// also accepts a leading `+`, which is not a valid URL port.
///
/// # Errors
///
/// Returns a decode error when `value` is empty, contains a non-digit, or
/// does not fit in a `u16`.
fn parse_port(value: &str) -> Result<u16, RuntimeError> {
    let invalid = || decode_error("websocket URL port must be a valid u16");
    if !value.bytes().all(|byte| byte.is_ascii_digit()) {
        return Err(invalid());
    }
    // An empty string passes the digit check but is rejected by `parse`.
    value.parse().map_err(|_| invalid())
}
/// Builds the `MethodNotFound` error for an unknown native method id.
fn method_not_found(method: u32) -> RuntimeError {
    let message = format!("native method id `{method}` was not found");
    RuntimeError::runtime(RuntimeErrorCode::MethodNotFound, message)
}
/// Builds a `PayloadDecodeFailed` error carrying `message`.
fn decode_error(message: impl Into<String>) -> RuntimeError {
    let code = RuntimeErrorCode::PayloadDecodeFailed;
    RuntimeError::runtime(code, message)
}
/// Builds an `InternalRuntimeError` error carrying `message`.
fn runtime_error(message: impl Into<String>) -> RuntimeError {
    let code = RuntimeErrorCode::InternalRuntimeError;
    RuntimeError::runtime(code, message)
}
/// Builds the `AccessDenied` error reported when the native policy refuses
/// `operation` within `scope`.
fn policy_denied(scope: &str, operation: &str) -> RuntimeError {
    let message = format!("native policy denied {scope} operation `{operation}`");
    RuntimeError::runtime(RuntimeErrorCode::AccessDenied, message)
}
/// Converts an I/O error into a runtime error that carries its display text.
fn io_error(error: std::io::Error) -> RuntimeError {
    let message = error.to_string();
    runtime_error(message)
}
/// Converts a tungstenite websocket error into a runtime error that carries
/// its display text.
fn ws_error(error: tokio_tungstenite::tungstenite::Error) -> RuntimeError {
    let message = error.to_string();
    runtime_error(message)
}
/// Converts a rusqlite error into a runtime error that carries its display
/// text.
fn sqlite_error(error: rusqlite::Error) -> RuntimeError {
    let message = error.to_string();
    runtime_error(message)
}
/// Builds a socket address from a literal IP host and a port.
///
/// Bare IPv4 and IPv6 literals (`127.0.0.1`, `::1`) are parsed directly.
/// Anything else, such as an already bracketed `[::1]`, goes through the
/// textual `host:port` form. Host names are not resolved.
///
/// # Errors
///
/// Returns a runtime error when the host is not an IP literal.
// NOTE(review): no caller is visible in this part of the file, hence the
// lint allowance — confirm whether the helper is still needed.
#[allow(dead_code)]
fn parse_addr(host: &str, port: u16) -> Result<SocketAddr, RuntimeError> {
    // `"{host}:{port}"` cannot express an unbracketed IPv6 literal, so try
    // the host as a plain IP address first.
    if let Ok(ip) = host.parse::<std::net::IpAddr>() {
        return Ok(SocketAddr::new(ip, port));
    }
    format!("{host}:{port}")
        .parse()
        .map_err(|error| runtime_error(format!("invalid socket address: {error}")))
}