use std::path::{Component, Path};
use tokio_stream::StreamExt;
use builder_pattern::Builder;
use serde::{de::DeserializeOwned, Deserialize};
use tonic::{
transport::{Certificate, Channel, ClientTlsConfig, Endpoint, Identity, Uri},
IntoRequest,
};
use proto::*;
mod proto;
/// Connection settings for the API client, deserialized with serde
/// (see `ClientConfig::from_yaml_file` for the YAML entry point).
///
/// The certificate and key fields are handed directly to tonic's `from_pem`
/// constructors, so they hold the PEM text itself rather than file paths.
// `dead_code` is allowed because `name` is not read by any code visible in this file.
#[allow(dead_code)]
#[derive(Deserialize)]
pub struct ClientConfig {
/// PEM text installed as the CA certificate of the TLS configuration.
ca_certificate: String,
/// PEM text of the client certificate presented as the TLS identity.
client_cert: String,
/// PEM text of the private key belonging to `client_cert`.
client_private_key: String,
/// Used verbatim as the authority part of the `https://` endpoint URI
/// (presumably `host:port` — confirm against the config files in use).
api_connection_string: String,
/// Deserialized but not used by the code visible here.
name: String,
}
impl ClientConfig {
    /// Reads the YAML file at `path` and deserializes it into a `ClientConfig`.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be opened or if its contents do not
    /// deserialize into this struct.
    pub fn from_yaml_file<P: AsRef<Path>>(path: &P) -> Result<Self, Box<dyn std::error::Error>> {
        let file = std::fs::File::open(path)?;
        let config: Self = serde_yaml::from_reader(file)?;
        Ok(config)
    }

    /// Assembles the mutual-TLS settings for the gRPC channel from the PEM
    /// material stored in this configuration.
    ///
    /// The expected server name is pinned to `"VelociraptorServer"` rather
    /// than being derived from the connection string.
    fn tls_config(&self) -> ClientTlsConfig {
        // `from_pem` only needs to view the bytes, so borrow instead of cloning.
        let trust_root = Certificate::from_pem(self.ca_certificate.as_bytes());
        let client_identity = Identity::from_pem(
            self.client_cert.as_bytes(),
            self.client_private_key.as_bytes(),
        );

        ClientTlsConfig::new()
            .domain_name("VelociraptorServer")
            .ca_certificate(trust_root)
            .identity(client_identity)
    }
}
/// API client handle.
///
/// Holds only the configured tonic `Endpoint`; the code in this file opens a
/// connection from it each time a request method needs one.
pub struct Client {
/// Target URI plus TLS settings, built by `TryFrom<&ClientConfig>`.
endpoint: Endpoint,
}
impl TryFrom<&ClientConfig> for Client {
type Error = Box<dyn std::error::Error>;
fn try_from(cfg: &ClientConfig) -> Result<Self, Self::Error> {
let uri = Uri::builder()
.scheme("https")
.authority(cfg.api_connection_string.as_str())
.path_and_query("/")
.build()?;
let endpoint = Endpoint::from(uri).tls_config(cfg.tls_config())?;
Ok(Self { endpoint })
}
}
/// Per-query settings passed to `Client::query`.
///
/// Constructed through the builder generated by `#[derive(Builder)]`.
#[derive(Builder)]
pub struct QueryOptions {
/// Key/value pairs forwarded to the server as the query's `env` entries.
#[public]
#[into]
env: Vec<(String, String)>,
/// Organisation id sent with the query; `None` is sent as an empty string.
#[public]
#[into]
org_id: Option<String>,
/// Forwarded as the request's `max_row` field; defaults to 10.
/// NOTE(review): presumably the number of rows per response message — confirm
/// against the server API documentation.
#[public]
#[default(10)]
max_row: u64,
}
impl Client {
    /// Connects to the configured endpoint and wraps the channel in the
    /// generated gRPC API client.
    ///
    /// # Errors
    ///
    /// Fails if the connection to the endpoint cannot be established.
    async fn api_client(
        &self,
    ) -> Result<api_client::ApiClient<Channel>, Box<dyn std::error::Error>> {
        let channel = self.endpoint.connect().await?;
        Ok(api_client::ApiClient::new(channel))
    }

    /// Runs a single VQL `query` on the server and collects every returned
    /// row, deserialized as `T`.
    ///
    /// Each streamed message carries a JSON array of rows in `response`
    /// and/or a log line in `log`; rows from all messages are concatenated in
    /// arrival order.
    ///
    /// # Errors
    ///
    /// Fails if
    /// * the connection or the initial request fails,
    /// * the response stream yields an error part-way through,
    /// * a `response` payload does not deserialize into `Vec<T>`, or
    /// * the server emits a log line starting with `"VQL Error:"` (the log
    ///   line itself becomes the error message).
    pub async fn query<T: DeserializeOwned>(
        &self,
        query: &str,
        options: &QueryOptions,
    ) -> Result<Vec<T>, Box<dyn std::error::Error>> {
        let env: Vec<_> = options
            .env
            .iter()
            .map(|(key, value)| VqlEnv {
                key: key.clone(),
                value: value.clone(),
            })
            .collect();

        let request = VqlCollectorArgs {
            env,
            org_id: options.org_id.clone().unwrap_or_default(),
            max_row: options.max_row,
            query: vec![VqlRequest {
                name: "".into(),
                vql: query.into(),
            }],
            ..VqlCollectorArgs::default()
        };

        let mut client = self.api_client().await?;
        let mut stream = client.query(request.into_request()).await?.into_inner();

        let mut rows: Vec<T> = Vec::new();
        while let Some(item) = stream.next().await {
            // An `Err` item is a failed stream, not a clean end-of-stream.
            // Propagate it so callers never mistake a truncated result set
            // for a complete one.
            let msg = item?;

            if !msg.response.is_empty() {
                log::trace!("result = {}", &msg.response);
                let batch: Vec<T> = serde_json::from_str(&msg.response)?;
                rows.extend(batch);
            }

            if !msg.log.is_empty() {
                log::debug!("log = {}", msg.log.trim());
                if msg.log.starts_with("VQL Error:") {
                    return Err(msg.log.into());
                }
            }
        }
        Ok(rows)
    }

    /// Downloads the complete contents of the VFS file addressed by `path`.
    ///
    /// Only the normal (named) components of `path` are sent to the server;
    /// root, prefix, `.` and `..` components are dropped. Non-UTF-8 components
    /// are converted lossily. The file is read in 1024-byte requests until the
    /// server returns an empty buffer.
    ///
    /// # Errors
    ///
    /// Fails if the connection cannot be established or any chunk request
    /// fails.
    pub async fn fetch<P: AsRef<Path>>(
        &self,
        path: P,
    ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
        let components: Vec<_> = path
            .as_ref()
            .components()
            .filter_map(|component| match component {
                Component::Normal(name) => Some(name.to_string_lossy().into_owned()),
                _ => None,
            })
            .collect();

        // Template for every chunk request; only `offset` varies per round trip.
        let template = VfsFileBuffer {
            components,
            length: 1024,
            ..VfsFileBuffer::default()
        };

        let mut client = self.api_client().await?;
        let mut contents = Vec::new();
        let mut offset = 0;
        loop {
            let chunk_request = VfsFileBuffer {
                offset,
                ..template.clone()
            };
            let chunk = client
                .vfs_get_buffer(chunk_request.into_request())
                .await?
                .into_inner()
                .data;

            // An empty buffer is the server's end-of-file signal.
            if chunk.is_empty() {
                break;
            }
            // Advance by what was actually received; the server may return
            // fewer bytes than requested.
            offset += chunk.len() as u64;
            contents.extend(chunk);
        }
        Ok(contents)
    }
}