//! Asynchronous client for the Watchman service.
//!
//! Use [`Connector`] to establish a connection and obtain a [`Client`];
//! the [`prelude`] module re-exports the items most callers need.
#![cfg_attr(fbcode_build, deny(warnings))]
// Re-exported wholesale through `prelude`.
pub mod expr;
// Re-exported wholesale through `prelude`.
pub mod fields;
// Transport referenced by `Connector::connect` on Windows builds.
mod named_pipe;
// Re-exported wholesale through `prelude`.
pub mod pdu;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::io;
use std::io::Write;
use std::marker::PhantomData;
use std::path::Path;
use std::path::PathBuf;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use bytes::Bytes;
use bytes::BytesMut;
use futures::future::FutureExt;
use futures::stream::StreamExt;
use serde_bser::de::Bunser;
use serde_bser::de::SliceRead;
pub use serde_bser::value::Value;
use thiserror::Error;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::process::Command;
use tokio::sync::mpsc::Receiver;
use tokio::sync::mpsc::Sender;
use tokio::sync::mpsc::UnboundedReceiver;
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::Mutex;
use tokio_util::codec::Decoder;
use tokio_util::codec::FramedRead;
/// Process-wide counter; each call to `Client::subscribe` takes the next value
/// so that the subscription names it generates are distinct.
static SUB_ID: AtomicUsize = AtomicUsize::new(1);
/// Convenience re-exports: glob-import this module to bring the expression,
/// field and PDU definitions plus the main client types into scope.
pub mod prelude {
pub use crate::expr::*;
pub use crate::fields::*;
pub use crate::pdu::*;
pub use crate::query_result_type;
pub use crate::CanonicalPath;
pub use crate::Client;
pub use crate::Connector;
pub use crate::ResolvedRoot;
}
use prelude::*;
/// Reasons the connection serviced by the background client task is no longer usable.
#[derive(Error, Debug)]
pub enum ConnectionLost {
/// A channel to or from the client task was closed, i.e. the task is gone.
#[error("Client task exited")]
ClientTaskExited,
/// The client task failed a pending request; the payload is the failure
/// text that the task reported.
#[error("Client task failed: {0}")]
Error(String),
}
#[derive(Error, Debug)]
pub enum Error {
#[error("Failed to connect to Watchman: {0}")]
ConnectionError(tokio::io::Error),
#[error("Lost connection to watchman")]
ConnectionLost(#[from] ConnectionLost),
#[error(
"While invoking the {watchman_path} CLI to discover the server connection details: {reason}, stderr=`{stderr}`"
)]
ConnectionDiscovery {
watchman_path: PathBuf,
reason: String,
stderr: String,
},
#[error("The watchman server reported an error: \"{}\", while executing command: {}", .message, .command)]
WatchmanServerError { message: String, command: String },
#[error("The watchman server reported an error: \"{}\"", .message)]
WatchmanResponseError { message: String },
#[error("The watchman server didn't return a value for field `{}` in response to a `{}` command. {:?}", .fieldname, .command, .response)]
MissingField {
fieldname: &'static str,
command: String,
response: String,
},
#[error("Deserialization error (data: {data:x?})")]
Deserialize {
data: Vec<u8>,
#[source]
source: anyhow::Error,
},
#[error("Seriaization error")]
Serialize {
#[source]
source: anyhow::Error,
},
#[error("Failed to connect to {endpoint}")]
Connect {
endpoint: PathBuf,
#[source]
source: Box<std::io::Error>,
},
}
/// Failures internal to the background client task. Their display text is
/// what gets forwarded to callers whose requests are failed by the task.
#[derive(Error, Debug)]
enum TaskError {
/// Reading from or writing to the underlying stream failed.
#[error("IO Error: {0}")]
Io(#[from] std::io::Error),
/// Used when the task is dropped while requests are still queued.
#[error("Task is shutting down")]
Shutdown,
/// The read side of the stream ended.
#[error("EOF on Watchman socket")]
Eof,
/// A non-subscription PDU arrived while no request was awaiting a response.
#[error("Received a unilateral PDU from the server")]
UnilateralPdu,
/// The incoming byte stream could not be split into a valid PDU.
#[error("Deserialization error (data: {data:x?})")]
Deserialize {
#[source]
source: anyhow::Error,
// The buffered bytes at the time of the failure.
data: Vec<u8>,
},
}
/// Builder-style configuration used to establish a connection and produce a `Client`.
#[derive(Default)]
pub struct Connector {
// Executable used for socket discovery; when unset, `watchman` is invoked by name.
watchman_cli_path: Option<PathBuf>,
// Explicit socket path; when set, discovery through the CLI is skipped.
unix_domain: Option<PathBuf>,
}
impl Connector {
/// Creates a connector with default settings, honouring `WATCHMAN_SOCK`.
///
/// When the `WATCHMAN_SOCK` environment variable is set, its value is used as
/// the socket path, exactly as if `unix_domain_socket` had been called with it.
pub fn new() -> Self {
    match std::env::var_os("WATCHMAN_SOCK") {
        Some(sock) => Self::default().unix_domain_socket(sock),
        None => Self::default(),
    }
}
pub fn watchman_cli_path<P: AsRef<Path>>(mut self, path: P) -> Self {
self.watchman_cli_path = Some(path.as_ref().to_path_buf());
self
}
pub fn unix_domain_socket<P: AsRef<Path>>(mut self, path: P) -> Self {
self.unix_domain = Some(path.as_ref().to_path_buf());
self
}
async fn resolve_unix_domain_path(&self) -> Result<PathBuf, Error> {
if let Some(path) = self.unix_domain.as_ref() {
Ok(path.clone())
} else {
let watchman_path = self
.watchman_cli_path
.as_ref()
.map_or_else(|| Path::new("watchman"), |p| p.as_ref());
let mut cmd = Command::new(watchman_path);
cmd.args(["--output-encoding", "bser-v2", "get-sockname"]);
#[cfg(windows)]
cmd.creation_flags(winapi::um::winbase::CREATE_NO_WINDOW);
let output = cmd
.output()
.await
.map_err(|source| Error::ConnectionDiscovery {
watchman_path: watchman_path.to_path_buf(),
reason: source.to_string(),
stderr: "".to_string(),
})?;
let info: GetSockNameResponse =
serde_bser::from_slice(&output.stdout).map_err(|source| {
Error::ConnectionDiscovery {
watchman_path: watchman_path.to_path_buf(),
reason: source.to_string(),
stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
}
})?;
let debug = format!("{:#?}", info);
if let Some(message) = info.error {
return Err(Error::WatchmanServerError {
message,
command: "get-sockname".into(),
});
}
info.sockname.ok_or_else(|| Error::MissingField {
fieldname: "sockname",
command: "get-sockname".into(),
response: debug,
})
}
}
/// Connects to the server and returns a `Client` bound to that connection.
///
/// The socket path is resolved first; the resulting stream is then handed to
/// a spawned background task that owns all reading and writing. If that task
/// ends with an error, the error is written to stderr on a best-effort basis.
///
/// # Errors
///
/// Returns any error from resolving the socket path, or from opening the stream.
pub async fn connect(&self) -> Result<Client, Error> {
    let sock_path = self.resolve_unix_domain_path().await?;

    #[cfg(unix)]
    let stream = UnixStream::connect(sock_path)
        .await
        .map_err(Error::ConnectionError)?;
    #[cfg(windows)]
    let stream = named_pipe::NamedPipe::connect(sock_path).await?;

    // Erase the platform-specific stream type before splitting it.
    let boxed: Box<dyn ReadWriteStream> = Box::new(stream);
    let (read_half, write_half) = tokio::io::split(boxed);
    let (request_tx, request_rx) = tokio::sync::mpsc::channel(128);

    let mut task = ClientTask {
        writer: write_half,
        reader: FramedRead::new(read_half, BserSplitter),
        request_rx,
        request_queue: VecDeque::new(),
        waiting_response: false,
        subscriptions: HashMap::new(),
    };
    tokio::spawn(async move {
        if let Err(err) = task.run().await {
            let _ignored = writeln!(io::stderr(), "watchman client task failed: {}", err);
        }
    });

    Ok(Client {
        inner: Arc::new(Mutex::new(ClientInner { request_tx })),
    })
}
}
/// A path that has been canonicalized (or is asserted to be absolute), with
/// the Windows `\\?\` escape prefix removed where applicable.
#[derive(Debug, Clone)]
pub struct CanonicalPath(PathBuf);

impl CanonicalPath {
    /// Canonicalizes `path` via the filesystem and wraps the result.
    ///
    /// # Errors
    ///
    /// Propagates the I/O error from `std::fs::canonicalize`.
    pub fn canonicalize<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
        std::fs::canonicalize(path)
            .map(Self::strip_unc_escape)
            .map(Self)
    }

    /// Wraps a path that the caller has already canonicalized.
    ///
    /// # Panics
    ///
    /// Panics if `path` is not absolute.
    pub fn with_canonicalized_path(path: PathBuf) -> Self {
        assert!(
            path.is_absolute(),
            "attempted to call \
             CanonicalPath::with_canonicalized_path on a non-canonical path! \
             You probably want to call CanonicalPath::canonicalize instead!"
        );
        let stripped = Self::strip_unc_escape(path);
        Self(stripped)
    }

    /// Drops a leading `\\?\` from paths that are valid UTF-8; other paths
    /// are returned untouched.
    #[cfg(windows)]
    #[inline]
    fn strip_unc_escape(path: PathBuf) -> PathBuf {
        match path.to_str().and_then(|s| s.strip_prefix("\\\\?\\")) {
            Some(rest) => PathBuf::from(rest),
            None => path,
        }
    }

    /// No escape prefix exists on Unix, so the path is returned as is.
    #[cfg(unix)]
    #[inline]
    fn strip_unc_escape(path: PathBuf) -> PathBuf {
        path
    }

    /// Consumes the wrapper and yields the underlying path.
    pub fn into_path_buf(self) -> PathBuf {
        let Self(inner) = self;
        inner
    }
}
/// The outcome of resolving a path against the server's watches: the watched
/// root, the optional path of the requested directory relative to that root,
/// and the name of the watcher in use.
#[derive(Debug, Clone)]
pub struct ResolvedRoot {
    root: PathBuf,
    relative: Option<PathBuf>,
    watcher: String,
}

impl ResolvedRoot {
    /// Name of the watcher reported by the server.
    pub fn watcher(&self) -> &str {
        &self.watcher
    }

    /// The watched root itself.
    pub fn project_root(&self) -> &Path {
        self.root.as_path()
    }

    /// The root joined with the relative path, or just the root when there
    /// is no relative component.
    pub fn path(&self) -> PathBuf {
        match &self.relative {
            Some(relative) => self.root.join(relative),
            None => self.root.clone(),
        }
    }

    /// The relative component, if any.
    pub fn project_relative_path(&self) -> Option<&Path> {
        self.relative.as_deref()
    }
}
/// Bundle of the stream capabilities the client task needs, so that the
/// concrete transport can be stored as `Box<dyn ReadWriteStream>`.
trait ReadWriteStream: AsyncRead + AsyncWrite + std::marker::Unpin + Send {}
#[cfg(unix)]
impl ReadWriteStream for UnixStream {}
/// A serialized request together with the channel on which its outcome is
/// delivered back to the caller.
struct SendRequest {
    /// The encoded request bytes written to the server.
    buf: Vec<u8>,
    /// Receives either the raw response PDU or a failure description.
    tx: tokio::sync::oneshot::Sender<Result<Bytes, String>>,
}

impl SendRequest {
    /// Delivers `result` to the waiting caller, consuming the request.
    ///
    /// A send failure is ignored on purpose: the outcome is simply discarded.
    fn respond(self, result: Result<Bytes, String>) {
        let SendRequest { tx, .. } = self;
        let _ignored = tx.send(result);
    }
}
/// Message forwarded from the client task to a `Subscription`.
enum SubscriptionNotification {
/// A raw unilateral PDU addressed to the subscription.
Pdu(Bytes),
/// The unilateral PDU had its `canceled` flag set.
Canceled,
}
/// Work items sent to the client task over its request channel.
enum TaskItem {
/// Queue a serialized request for transmission.
QueueRequest(SendRequest),
/// Route unilateral PDUs for the named subscription to the given sender.
RegisterSubscription(String, UnboundedSender<SubscriptionNotification>),
}
struct BserSplitter;
impl Decoder for BserSplitter {
type Item = Bytes;
type Error = TaskError;
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let mut bunser = Bunser::new(SliceRead::new(buf.as_ref()));
let pdu = match bunser.read_pdu() {
Ok(pdu) => pdu,
Err(source) => {
const BUF_SIZE: usize = 16;
let missing = BUF_SIZE.saturating_sub(buf.len());
if missing > 0 {
buf.reserve(missing);
return Ok(None);
}
return Err(TaskError::Deserialize {
source: source.into(),
data: buf.to_vec(),
});
}
};
let total_size = (pdu.start + pdu.len) as usize;
let missing = total_size.saturating_sub(buf.len());
if missing > 0 {
buf.reserve(missing);
return Ok(None);
}
let ret = buf.split_to(total_size);
Ok(Some(ret.freeze()))
}
}
/// Handle used to issue commands over an established connection.
///
/// Requests go through the shared `ClientInner`, which is guarded by an
/// async mutex.
pub struct Client {
inner: Arc<Mutex<ClientInner>>,
}
/// State owned by the background task that performs all socket I/O.
struct ClientTask {
// Write half of the connection; requests are written here.
writer: tokio::io::WriteHalf<Box<dyn ReadWriteStream>>,
// Read half of the connection, framed into whole PDUs by `BserSplitter`.
reader: FramedRead<tokio::io::ReadHalf<Box<dyn ReadWriteStream>>, BserSplitter>,
// Incoming work items from `ClientInner` / `Client::subscribe`.
request_rx: Receiver<TaskItem>,
// Requests not yet answered; the front entry is the one in flight (if any).
request_queue: VecDeque<SendRequest>,
// True while the front request has been written and its response is pending.
waiting_response: bool,
// Destination channel for unilateral PDUs, keyed by subscription name.
subscriptions: HashMap<String, UnboundedSender<SubscriptionNotification>>,
}
impl Drop for ClientTask {
    /// Fails every request that is still queued so that no caller is left
    /// waiting on a task that no longer exists.
    fn drop(&mut self) {
        let reason = TaskError::Shutdown;
        self.fail_all(&reason);
    }
}
impl ClientTask {
/// Runs the task to completion.
///
/// If the main loop ends with an error, every queued request is failed with
/// that error's text before the error is returned.
async fn run(&mut self) -> Result<(), TaskError> {
    let outcome = self.run_loop().await;
    if let Err(err) = &outcome {
        self.fail_all(err);
    }
    outcome
}
/// Main loop of the client task: services PDUs arriving from the server and
/// work items arriving from clients.
///
/// Returns `Ok(())` once the request channel yields `None`, and
/// `TaskError::Eof` when the reader stream ends. Any error from handling a
/// PDU or queueing a request is propagated.
async fn run_loop(&mut self) -> Result<(), TaskError> {
loop {
// `select_biased!` polls its branches in declaration order, so a PDU that
// is ready is handled before the next client work item.
futures::select_biased! {
pdu = self.reader.next().fuse() => {
match pdu {
Some(pdu) => self.process_pdu(pdu?).await?,
// The framed reader ended: nothing more will arrive.
None => return Err(TaskError::Eof),
}
}
task = self.request_rx.recv().fuse() => {
match task {
Some(TaskItem::QueueRequest(request)) => self.queue_request(request).await?,
Some(TaskItem::RegisterSubscription(name, tx)) => {
self.register_subscription(name, tx)
}
// The channel is closed, so there is nobody left to serve.
None => break,
}
}
}
}
Ok(())
}
/// Records `tx` as the destination for unilateral PDUs addressed to the
/// subscription called `name`, replacing any earlier entry with that name.
fn register_subscription(
    &mut self,
    name: String,
    tx: UnboundedSender<SubscriptionNotification>,
) {
    let _replaced = self.subscriptions.insert(name, tx);
}
/// Empties the request queue, answering each request (front to back) with
/// the display text of `err`.
fn fail_all(&mut self, err: &TaskError) {
    for request in self.request_queue.drain(..) {
        request.respond(Err(err.to_string()));
    }
}
/// Writes the request at the front of the queue, unless a response is
/// already pending or the queue is empty.
///
/// The request stays in the queue after being written; it is removed once
/// its response arrives. A write failure is returned as `TaskError::Io`.
async fn send_next_request(&mut self) -> Result<(), TaskError> {
    if self.waiting_response {
        return Ok(());
    }
    if let Some(payload) = self.request_queue.front().map(|request| &request.buf) {
        self.writer.write_all(payload).await?;
        self.waiting_response = true;
    }
    Ok(())
}
/// Appends `request` to the queue and, if nothing is in flight, sends the
/// request at the front of the queue.
async fn queue_request(&mut self, request: SendRequest) -> Result<(), TaskError> {
    self.request_queue.push_back(request);
    self.send_next_request().await
}
async fn process_pdu(&mut self, pdu: Bytes) -> Result<(), TaskError> {
use serde::Deserialize;
#[derive(Deserialize, Debug)]
pub struct Unilateral {
#[allow(unused)]
pub unilateral: bool,
pub subscription: String,
#[serde(default)]
pub canceled: bool,
}
if let Ok(unilateral) = bunser::<Unilateral>(&pdu) {
if let Some(subscription) = self.subscriptions.get_mut(&unilateral.subscription) {
let msg = if unilateral.canceled {
SubscriptionNotification::Canceled
} else {
SubscriptionNotification::Pdu(pdu)
};
if subscription.send(msg).is_err() || unilateral.canceled {
self.subscriptions.remove(&unilateral.subscription);
}
}
} else if self.waiting_response {
let request = self
.request_queue
.pop_front()
.expect("waiting_response is only true when request_queue is not empty");
self.waiting_response = false;
request.respond(Ok(pdu));
} else {
return Err(TaskError::UnilateralPdu);
}
self.send_next_request().await?;
Ok(())
}
}
fn bunser<T>(buf: &[u8]) -> Result<T, Error>
where
T: serde::de::DeserializeOwned,
{
let response: T = serde_bser::from_slice(buf).map_err(|source| Error::Deserialize {
source: source.into(),
data: buf.to_vec(),
})?;
Ok(response)
}
/// The client-side end of the connection: the channel through which work is
/// handed to the background client task.
struct ClientInner {
request_tx: Sender<TaskItem>,
}
impl ClientInner {
pub(crate) async fn generic_request<Request, Response>(
&mut self,
request: Request,
) -> Result<Response, Error>
where
Request: serde::Serialize + std::fmt::Debug,
Response: serde::de::DeserializeOwned,
{
let mut request_data = vec![];
serde_bser::ser::serialize(&mut request_data, &request).map_err(|source| {
Error::Serialize {
source: source.into(),
}
})?;
let (tx, rx) = tokio::sync::oneshot::channel();
self.request_tx
.send(TaskItem::QueueRequest(SendRequest {
buf: request_data,
tx,
}))
.await
.map_err(|_| ConnectionLost::ClientTaskExited)?;
let pdu_data = rx
.await
.map_err(|_| ConnectionLost::ClientTaskExited)?
.map_err(ConnectionLost::Error)?;
use serde::Deserialize;
#[derive(Deserialize, Debug)]
struct MaybeError {
#[serde(default)]
error: Option<String>,
}
let maybe_err: MaybeError = bunser(&pdu_data)?;
if let Some(message) = maybe_err.error {
return Err(Error::WatchmanServerError {
message,
command: format!("{:#?}", request),
});
}
let response: Response = bunser(&pdu_data)?;
Ok(response)
}
}
/// The kinds of item that `Subscription::next` can yield.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone)]
pub enum SubscriptionData<F>
where
F: serde::de::DeserializeOwned + std::fmt::Debug + Clone + QueryFieldList,
{
/// The subscription was canceled.
Canceled,
/// A query result that was neither a state-enter nor a state-leave notice.
FilesChanged(QueryResult<F>),
/// The notification carried a `state_enter` value.
StateEnter {
state_name: String,
metadata: Option<Value>,
},
/// The notification carried a `state_leave` value.
StateLeave {
state_name: String,
metadata: Option<Value>,
},
}
/// A live subscription created by `Client::subscribe`; notifications are
/// pulled with `next` and decoded with the field list `F`.
pub struct Subscription<F>
where
F: serde::de::DeserializeOwned + std::fmt::Debug + Clone + QueryFieldList,
{
// Name under which the subscription was registered.
name: String,
// Shared connection, used to send the unsubscribe command.
inner: Arc<Mutex<ClientInner>>,
// Root the subscription was created against.
root: ResolvedRoot,
// Notifications forwarded by the client task.
responses: UnboundedReceiver<SubscriptionNotification>,
// Ties the field-list type `F` to the subscription without storing a value.
_phantom: PhantomData<F>,
}
impl<F> Subscription<F>
where
F: serde::de::DeserializeOwned + std::fmt::Debug + Clone + QueryFieldList,
{
/// Returns the name under which this subscription was registered.
pub fn name(&self) -> &str {
    self.name.as_str()
}
/// Waits for the next notification on this subscription and classifies it.
///
/// A notification with `state_enter` set yields `StateEnter`, one with
/// `state_leave` set yields `StateLeave`, and anything else yields
/// `FilesChanged`. A cancellation closes the receiver and yields `Canceled`.
///
/// # Errors
///
/// * `Error::ConnectionLost` when the notification channel has ended.
/// * `Error::Deserialize` when the notification cannot be decoded.
#[allow(clippy::should_implement_trait)]
pub async fn next(&mut self) -> Result<SubscriptionData<F>, Error> {
    let notification = match self.responses.recv().await {
        Some(notification) => notification,
        None => return Err(ConnectionLost::ClientTaskExited.into()),
    };

    let pdu = match notification {
        SubscriptionNotification::Canceled => {
            self.responses.close();
            return Ok(SubscriptionData::Canceled);
        }
        SubscriptionNotification::Pdu(pdu) => pdu,
    };

    let response: QueryResult<F> = bunser(&pdu)?;
    if let Some(state_name) = response.state_enter {
        return Ok(SubscriptionData::StateEnter {
            state_name,
            metadata: response.state_metadata,
        });
    }
    if let Some(state_name) = response.state_leave {
        return Ok(SubscriptionData::StateLeave {
            state_name,
            metadata: response.state_metadata,
        });
    }
    Ok(SubscriptionData::FilesChanged(response))
}
/// Consumes the subscription and sends an `unsubscribe` command for it.
///
/// # Errors
///
/// Returns whatever error the request produces.
pub async fn cancel(self) -> Result<(), Error> {
    let mut guard = self.inner.lock().await;
    let request = Unsubscribe("unsubscribe", self.root.root, self.name);
    let _response: UnsubscribeResponse = guard.generic_request(request).await?;
    Ok(())
}
}
impl Client {
/// Sends an arbitrary serializable request and decodes the response as
/// `Response`. The connection lock is held for the whole round trip.
#[doc(hidden)]
pub async fn generic_request<Request, Response>(
    &self,
    request: Request,
) -> Result<Response, Error>
where
    Request: serde::Serialize + std::fmt::Debug,
    Response: serde::de::DeserializeOwned,
{
    let mut guard = self.inner.lock().await;
    guard.generic_request(request).await
}
/// Issues the `version` command and returns the decoded response.
pub async fn version(&self) -> Result<GetVersionResponse, Error> {
    let command = ["version"];
    self.generic_request(&command).await
}
/// Issues the `watch-list` command and returns the decoded response.
pub async fn watch_list(&self) -> Result<WatchListResponse, Error> {
    let command = ["watch-list"];
    self.generic_request(&command).await
}
/// Issues a `state-enter` command for `state_name` on the given root,
/// passing along `sync_timeout` and the optional `metadata`.
///
/// The response body is decoded and then discarded.
pub async fn state_enter(
    &self,
    root: &ResolvedRoot,
    state_name: &str,
    sync_timeout: SyncTimeout,
    metadata: Option<Value>,
) -> Result<(), Error> {
    let params = StateEnterLeaveParams {
        name: state_name,
        metadata,
        sync_timeout,
    };
    let request = StateEnterLeaveRequest("state-enter", root.root.clone(), params);
    self.generic_request::<_, StateEnterLeaveResponse>(request)
        .await
        .map(|_response| ())
}
/// Issues a `state-leave` command for `state_name` on the given root,
/// passing along `sync_timeout` and the optional `metadata`.
///
/// The response body is decoded and then discarded.
pub async fn state_leave(
    &self,
    root: &ResolvedRoot,
    state_name: &str,
    sync_timeout: SyncTimeout,
    metadata: Option<Value>,
) -> Result<(), Error> {
    let params = StateEnterLeaveParams {
        name: state_name,
        metadata,
        sync_timeout,
    };
    let request = StateEnterLeaveRequest("state-leave", root.root.clone(), params);
    self.generic_request::<_, StateEnterLeaveResponse>(request)
        .await
        .map(|_response| ())
}
/// Issues `watch-project` for `path` and packages the server's answer (the
/// watch root, the optional relative path and the watcher name) as a
/// `ResolvedRoot`.
pub async fn resolve_root(&self, path: CanonicalPath) -> Result<ResolvedRoot, Error> {
    // `path` is owned, so its inner buffer is moved into the request
    // instead of being cloned.
    let request = WatchProjectRequest("watch-project", path.into_path_buf());
    let response: WatchProjectResponse = self.generic_request(request).await?;
    Ok(ResolvedRoot {
        root: response.watch,
        relative: response.relative_path,
        watcher: response.watcher,
    })
}
/// Runs a `query` command against `root`.
///
/// The `relative_root` and `fields` members of `query` are overwritten: the
/// former with the root's relative path, the latter with the field list
/// implied by `F`. All other members are sent as given.
pub async fn query<F>(
    &self,
    root: &ResolvedRoot,
    query: QueryRequestCommon,
) -> Result<QueryResult<F>, Error>
where
    F: serde::de::DeserializeOwned + std::fmt::Debug + Clone + QueryFieldList,
{
    let request = QueryRequest(
        "query",
        root.root.clone(),
        QueryRequestCommon {
            relative_root: root.relative.clone(),
            fields: F::field_list(),
            ..query
        },
    );
    // The request is not needed afterwards, so it is moved rather than cloned.
    self.generic_request(request).await
}
/// Creates a subscription on `root`.
///
/// A name of the form `sub-[<argv0>]-<n>` is generated for the subscription.
/// The `relative_root` and `fields` members of `query` are overwritten with
/// the root's relative path and the field list implied by `F`.
///
/// Returns the subscription handle together with the server's response to
/// the `subscribe` command.
///
/// # Errors
///
/// * `Error::ConnectionLost` if the client task is gone.
/// * Any error produced by the `subscribe` request itself.
pub async fn subscribe<F>(
    &self,
    root: &ResolvedRoot,
    query: SubscribeRequest,
) -> Result<(Subscription<F>, SubscribeResponse), Error>
where
    F: serde::de::DeserializeOwned + std::fmt::Debug + Clone + QueryFieldList,
{
    // `std::env::args()` panics when an argument is not valid Unicode; read
    // argv[0] through `args_os()` and convert lossily so that an unusual
    // executable path cannot crash the caller.
    let argv0 = std::env::args_os().next().map_or_else(
        || "<no-argv-0>".to_string(),
        |arg| arg.to_string_lossy().into_owned(),
    );
    let name = format!(
        "sub-[{}]-{}",
        argv0,
        SUB_ID.fetch_add(1, Ordering::Relaxed)
    );
    let query = SubscribeCommand(
        "subscribe",
        root.root.clone(),
        name.clone(),
        SubscribeRequest {
            relative_root: root.relative.clone(),
            fields: F::field_list(),
            ..query
        },
    );
    let (tx, responses) = tokio::sync::mpsc::unbounded_channel();
    {
        // Register the notification channel before sending the command, so
        // that notifications arriving right after the server accepts the
        // subscription already have a destination.
        let inner = self.inner.lock().await;
        inner
            .request_tx
            .send(TaskItem::RegisterSubscription(name.clone(), tx))
            .await
            .map_err(|_| ConnectionLost::ClientTaskExited)?;
    }
    let subscription = Subscription::<F> {
        name,
        inner: Arc::clone(&self.inner),
        root: root.clone(),
        responses,
        _phantom: PhantomData,
    };
    let response: SubscribeResponse = self.generic_request(query).await?;
    Ok((subscription, response))
}
/// Runs a name-only query for the given glob patterns under `root` and
/// returns the matching file names. A response without a file list yields
/// an empty vector.
pub async fn glob(&self, root: &ResolvedRoot, globs: &[&str]) -> Result<Vec<PathBuf>, Error> {
    // `relative_root` is not set here: `query` always overwrites it with the
    // root's relative path.
    let request = QueryRequestCommon {
        glob: Some(globs.iter().map(|&s| s.to_string()).collect()),
        ..Default::default()
    };
    let response: QueryResult<NameOnly> = self.query(root, request).await?;
    Ok(response
        .files
        .unwrap_or_default()
        .into_iter()
        .map(|f| f.name.into_inner())
        .collect())
}
/// Issues the `clock` command for `root` with the given `sync_timeout` and
/// returns the clock value from the response.
pub async fn clock(
    &self,
    root: &ResolvedRoot,
    sync_timeout: SyncTimeout,
) -> Result<ClockSpec, Error> {
    let params = ClockRequestParams { sync_timeout };
    let request = ClockRequest("clock", root.root.clone(), params);
    self.generic_request::<_, ClockResponse>(request)
        .await
        .map(|response| response.clock)
}
pub async fn get_config(&self, root: &ResolvedRoot) -> Result<WatchmanConfig, Error> {
let response: GetConfigResponse = self
.generic_request(GetConfigRequest("get-config", root.root.clone()))
.await?;
Ok(response.config)
}
/// Issues the `trigger` command for `root` with the given trigger
/// definition and returns the decoded response.
pub async fn register_trigger(
    &self,
    root: &ResolvedRoot,
    request: TriggerRequest,
) -> Result<TriggerResponse, Error> {
    let command = TriggerCommand("trigger", root.root.clone(), request);
    self.generic_request(command).await
}
/// Issues the `trigger-del` command for the trigger called `name` on `root`
/// and returns the decoded response.
pub async fn remove_trigger(
    &self,
    root: &ResolvedRoot,
    name: &str,
) -> Result<TriggerDelResponse, Error> {
    let command = TriggerDelCommand("trigger-del", root.root.clone(), name.into());
    self.generic_request(command).await
}
/// Issues the `trigger-list` command for `root` and returns the decoded response.
pub async fn list_triggers(&self, root: &ResolvedRoot) -> Result<TriggerListResponse, Error> {
    let command = TriggerListCommand("trigger-list", root.root.clone());
    self.generic_request(command).await
}
}
// Unit tests for the connector builder, the PDU splitter and the error types.
#[cfg(test)]
mod tests {
use std::io;
use futures::stream;
use futures::stream::TryStreamExt;
use serde::Deserialize;
use serde::Serialize;
use tokio_util::io::StreamReader;
use super::*;
// Small payload type that is round-tripped through BSER in the decoder test.
#[derive(Serialize, Deserialize, PartialEq, Debug)]
struct TestStruct {
value: i32,
}
// An explicit socket path set on the builder is stored as given.
#[test]
fn connection_builder_paths() {
let builder = Connector::new().unix_domain_socket("/some/path");
assert_eq!(builder.unix_domain, Some(PathBuf::from("/some/path")));
}
// The splitter must yield the same messages regardless of how the byte
// stream is chunked on its way in.
#[tokio::test]
async fn test_decoder() {
// Feeds `buf` to the splitter in pieces of `chunk_size` bytes and decodes
// every PDU it produces.
async fn read_bser(buf: &[u8], chunk_size: usize) -> Vec<TestStruct> {
let chunks = buf
.chunks(chunk_size)
.map(|c| Result::<_, io::Error>::Ok(Bytes::copy_from_slice(c)));
let reader = StreamReader::new(stream::iter(chunks));
FramedRead::new(reader, BserSplitter)
.map_err(TaskError::from)
.and_then(|bytes| async move {
Ok(serde_bser::from_slice::<TestStruct>(&bytes).unwrap())
})
.try_collect()
.await
.unwrap()
}
let msgs = vec![
TestStruct { value: 1 },
TestStruct { value: 2 },
TestStruct { value: 3 },
];
// Serialize all messages back to back into a single buffer.
let mut buf = vec![];
for msg in msgs.iter() {
serde_bser::ser::serialize(&mut buf, msg).expect("Failed to write to a Vec");
}
assert_eq!(msgs, read_bser(&buf, 1).await);
assert_eq!(msgs, read_bser(&buf, 2).await);
assert_eq!(msgs, read_bser(&buf, 10).await);
assert_eq!(msgs, read_bser(&buf, buf.len()).await);
}
// Invalid data is tolerated while fewer than 16 bytes are buffered (more
// input is requested) and rejected once at least 16 bytes are available.
#[test]
fn test_decoder_err() {
let mut bytes = BytesMut::new();
bytes.extend_from_slice(&[0; 10]);
let r1 = BserSplitter.decode(&mut bytes);
assert!(r1.is_ok());
assert!(r1.unwrap().is_none());
bytes.extend_from_slice(&[0; 10]);
let r1 = BserSplitter.decode(&mut bytes);
assert!(r1.is_err());
}
// Compile-time check that both error types are `Error + Send + Sync + 'static`.
#[test]
fn test_bounds() {
fn assert_bounds<T: std::error::Error + Sync + Send + 'static>() {}
assert_bounds::<Error>();
assert_bounds::<TaskError>();
}
}