use std::time::Duration;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use crate::context::ExecutionContext;
use crate::error::ToolError;
use crate::registry::{Tool, ToolConfig};
use crate::result::ToolResult;
use crate::template::TemplateEngine;
/// Configuration for the `result_fetch` tool, deserialized from the rendered step config.
///
/// `Debug` is implemented by hand so that `bearer_token` (which may be a literal token
/// rather than a keychain alias) is never printed verbatim.
#[derive(Clone, Serialize, Deserialize)]
pub struct ResultFetchConfig {
    /// Result reference to resolve (serialized under the key `ref`).
    pub r#ref: String,
    /// Backend to try first; defaults to Arrow Flight.
    #[serde(default = "default_prefer")]
    pub prefer: BackendPreference,
    /// Explicit Flight endpoint; when absent it is derived from the server URL.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub flight_endpoint: Option<String>,
    /// Keychain alias or literal bearer token sent to the Flight server.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub bearer_token: Option<String>,
    /// Path to a PEM CA bundle used to verify the Flight server.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tls_ca_path: Option<String>,
    /// Path to the client certificate for mTLS; must be paired with `client_key_path`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub client_cert_path: Option<String>,
    /// Path to the client private key for mTLS; must be paired with `client_cert_path`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub client_key_path: Option<String>,
}

impl std::fmt::Debug for ResultFetchConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResultFetchConfig")
            .field("ref", &self.r#ref)
            .field("prefer", &self.prefer)
            .field("flight_endpoint", &self.flight_endpoint)
            // Show only whether a token is configured, never its value.
            .field(
                "bearer_token",
                &self.bearer_token.as_ref().map(|_| "<redacted>"),
            )
            .field("tls_ca_path", &self.tls_ca_path)
            .field("client_cert_path", &self.client_cert_path)
            .field("client_key_path", &self.client_key_path)
            .finish()
    }
}
/// Which backend `result_fetch` tries first.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BackendPreference {
    /// Arrow Flight first; non-tabular results and transport failures fall back to HTTP.
    #[serde(rename = "flight")]
    Flight,
    /// HTTP `/api/result/resolve` only.
    #[serde(rename = "http")]
    Http,
}
/// Serde default for `ResultFetchConfig::prefer`: Arrow Flight.
const fn default_prefer() -> BackendPreference {
    BackendPreference::Flight
}
/// Tool that resolves a result reference into data via Arrow Flight or the server's HTTP API.
pub struct ResultFetchTool {
    /// Renders the step config against the execution context before it is parsed.
    template_engine: TemplateEngine,
    /// Client used for the `/api/result/resolve` HTTP path.
    http_client: reqwest::Client,
}
impl ResultFetchTool {
    /// Creates a tool with a 30-second HTTP timeout and a fresh [`TemplateEngine`].
    ///
    /// If the timeout-configured client cannot be built, the failure is logged and
    /// `reqwest::Client::default()` is used instead. That client has no explicit timeout.
    pub fn new() -> Self {
        let http_client = reqwest::Client::builder()
            .timeout(Duration::from_secs(30))
            .build()
            .unwrap_or_else(|e| {
                tracing::warn!(
                    error = %e,
                    "failed to build result_fetch HTTP client with a timeout; using default client",
                );
                reqwest::Client::default()
            });
        Self {
            http_client,
            template_engine: TemplateEngine::new(),
        }
    }

    /// Renders the step's raw config through the template engine using `ctx`, then
    /// deserializes it into a [`ResultFetchConfig`].
    ///
    /// # Errors
    /// Propagates template-rendering errors. Returns `ToolError::Configuration` when the
    /// rendered value does not match the `ResultFetchConfig` shape.
    fn parse_config(
        &self,
        config: &ToolConfig,
        ctx: &ExecutionContext,
    ) -> Result<ResultFetchConfig, ToolError> {
        let template_ctx = ctx.to_template_context();
        let rendered = self
            .template_engine
            .render_value(&config.config, &template_ctx)?;
        serde_json::from_value(rendered)
            .map_err(|e| ToolError::Configuration(format!("Invalid result_fetch config: {}", e)))
    }

    /// Derives a Flight endpoint from the worker's `server_url`.
    ///
    /// An explicit `http://` or `https://` scheme is kept; when none is given, `http` is
    /// assumed. Trailing slashes are dropped, and the HTTP port `8082` is mapped to the
    /// Flight port `8083`. Any other authority is passed through unchanged.
    fn derive_flight_endpoint(server_url: &str) -> String {
        let (scheme, rest) = if let Some(rest) = server_url.strip_prefix("https://") {
            ("https", rest)
        } else if let Some(rest) = server_url.strip_prefix("http://") {
            ("http", rest)
        } else {
            ("http", server_url)
        };
        // `fetch_via_http` tolerates a trailing slash on `server_url`. Do the same here, so
        // that "http://host:8082/" still maps to the Flight port instead of the HTTP port.
        let authority = rest.trim_end_matches('/');
        match authority.strip_suffix(":8082") {
            Some(host) => format!("{scheme}://{host}:8083"),
            None => format!("{scheme}://{authority}"),
        }
    }

    /// Resolves `cfg.ref` via `GET {server_url}/api/result/resolve?ref=...` and returns the
    /// JSON body. Trailing slashes on `server_url` are trimmed first.
    ///
    /// # Errors
    /// Returns `ToolError::Http` in three cases: the request fails, the response status is
    /// not 2xx (the message includes the status code and body), or the body is not valid JSON.
    async fn fetch_via_http(
        &self,
        cfg: &ResultFetchConfig,
        ctx: &ExecutionContext,
    ) -> Result<JsonValue, ToolError> {
        let url = format!(
            "{}/api/result/resolve",
            ctx.server_url.trim_end_matches('/')
        );
        let response = self
            .http_client
            .get(&url)
            .query(&[("ref", cfg.r#ref.as_str())])
            .send()
            .await
            .map_err(|e| ToolError::Http(format!("HTTP fetch failed: {e}")))?;
        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(ToolError::Http(format!(
                "/api/result/resolve returned {}: {}",
                status.as_u16(),
                body
            )));
        }
        let body: JsonValue = response
            .json()
            .await
            .map_err(|e| ToolError::Http(format!("Failed to parse JSON response: {e}")))?;
        Ok(body)
    }

    /// Looks up `value` as a secret alias in `ctx`. If no secret matches, `value` itself
    /// is returned and used as the literal bearer token.
    fn resolve_bearer(value: &str, ctx: &ExecutionContext) -> String {
        match ctx.get_secret(value) {
            Some(token) => token.to_string(),
            None => {
                // Never log `value` here: on this branch it is used as the credential itself.
                tracing::debug!(
                    value_len = value.len(),
                    "bearer_token didn't match a keychain alias; treating as literal token",
                );
                value.to_string()
            }
        }
    }

    /// Builds the optional TLS/auth configuration for the Flight client.
    ///
    /// Returns `Ok(None)` when no CA bundle, client identity, or bearer token is configured.
    /// In that case the caller uses a plain `connect`.
    ///
    /// # Errors
    /// Returns `FlightFetchError::Transport` in two cases:
    /// - only one of `client_cert_path` / `client_key_path` is set;
    /// - one of the configured PEM files cannot be read.
    ///
    /// NOTE(review): because these local configuration problems are reported as `Transport`,
    /// `execute` treats them like network failures and falls back to HTTP. Confirm that this
    /// is intended.
    fn build_flight_config(
        cfg: &ResultFetchConfig,
        ctx: &ExecutionContext,
    ) -> Result<Option<noetl_arrow_flight_client::FlightConfig>, FlightFetchError> {
        use noetl_arrow_flight_client::{FlightAuth, FlightConfig, FlightTlsConfig};

        // A lone cert or a lone key cannot form a client identity. Reject it instead of
        // silently connecting without mTLS.
        if cfg.client_cert_path.is_some() != cfg.client_key_path.is_some() {
            return Err(FlightFetchError::Transport(format!(
                "client_cert_path and client_key_path must both be set or both be None; \
                 got cert={cert:?}, key={key:?}",
                cert = cfg.client_cert_path,
                key = cfg.client_key_path,
            )));
        }

        let mut tls: Option<FlightTlsConfig> = None;
        if let Some(ca_path) = &cfg.tls_ca_path {
            let ca_pem = std::fs::read(ca_path).map_err(|e| {
                FlightFetchError::Transport(format!("read TLS CA bundle from {ca_path}: {e}"))
            })?;
            tls = Some(tls.unwrap_or_default().ca_certificate(ca_pem));
        }
        if let (Some(cert_path), Some(key_path)) = (&cfg.client_cert_path, &cfg.client_key_path) {
            let cert_pem = std::fs::read(cert_path).map_err(|e| {
                FlightFetchError::Transport(format!("read client cert from {cert_path}: {e}"))
            })?;
            let key_pem = std::fs::read(key_path).map_err(|e| {
                FlightFetchError::Transport(format!("read client key from {key_path}: {e}"))
            })?;
            tls = Some(tls.unwrap_or_default().identity(cert_pem, key_pem));
        }

        let mut out: Option<FlightConfig> = None;
        if let Some(tls) = tls {
            out = Some(out.unwrap_or_default().tls(tls));
        }
        if let Some(bearer_value) = &cfg.bearer_token {
            let token = Self::resolve_bearer(bearer_value, ctx);
            let auth = FlightAuth::bearer(token);
            out = Some(out.unwrap_or_default().auth(auth));
        }
        Ok(out)
    }

    /// Resolves `cfg.ref` over Arrow Flight.
    ///
    /// Uses `cfg.flight_endpoint`, or one derived from `ctx.server_url` when it is absent.
    /// Returns `{"data": {"rows", "columns", "row_count"}, "status": "success"}`. The
    /// `columns` list is taken from the keys of the first row, and is empty when there are
    /// no rows or the first row is not an object. Client errors are mapped one-to-one onto
    /// [`FlightFetchError`] so that `execute` can decide whether to fall back to HTTP.
    async fn fetch_via_flight(
        &self,
        cfg: &ResultFetchConfig,
        ctx: &ExecutionContext,
    ) -> Result<JsonValue, FlightFetchError> {
        let endpoint = cfg
            .flight_endpoint
            .clone()
            .unwrap_or_else(|| Self::derive_flight_endpoint(&ctx.server_url));
        let flight_config = Self::build_flight_config(cfg, ctx)?;
        let resolver = match flight_config {
            Some(config) => {
                noetl_arrow_flight_client::FlightResolver::connect_with(&endpoint, config).await
            }
            None => noetl_arrow_flight_client::FlightResolver::connect(&endpoint).await,
        }
        .map_err(|e| {
            FlightFetchError::Transport(format!("connect to Flight endpoint {endpoint}: {e}"))
        })?;
        match resolver.resolve_rows(&cfg.r#ref).await {
            Ok(rows) => {
                let columns: Vec<String> = rows
                    .first()
                    .and_then(|row| row.as_object())
                    .map(|obj| obj.keys().cloned().collect())
                    .unwrap_or_default();
                Ok(serde_json::json!({
                    "data": {
                        "rows": rows,
                        "columns": columns,
                        "row_count": rows.len(),
                    },
                    "status": "success",
                }))
            }
            Err(noetl_arrow_flight_client::FlightError::NonTabular { ref_uri, message }) => {
                Err(FlightFetchError::NonTabular { ref_uri, message })
            }
            Err(noetl_arrow_flight_client::FlightError::Server(msg)) => {
                Err(FlightFetchError::Server(msg))
            }
            Err(noetl_arrow_flight_client::FlightError::Transport(msg)) => {
                Err(FlightFetchError::Transport(msg))
            }
        }
    }
}
/// Failure modes of `ResultFetchTool::fetch_via_flight`, classified so that `execute`
/// can decide whether to fall back to HTTP.
#[derive(Debug)]
enum FlightFetchError {
    /// The referenced result is not tabular. `execute` falls back to HTTP.
    NonTabular {
        /// The result reference that could not be served as rows.
        ref_uri: String,
        /// Explanation reported by the Flight client.
        message: String,
    },
    /// A connection failure or local setup failure (e.g. an unreadable PEM file).
    /// `execute` falls back to HTTP.
    Transport(String),
    /// The Flight server reported an error. `execute` returns it without fallback.
    Server(String),
}
impl Default for ResultFetchTool {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Tool for ResultFetchTool {
    fn name(&self) -> &'static str {
        "result_fetch"
    }

    /// Resolves the configured `ref` and returns the data as a successful [`ToolResult`].
    ///
    /// - `prefer: http` uses only the HTTP endpoint.
    /// - `prefer: flight` (the default) tries Flight first:
    ///   - a non-tabular signal falls back to HTTP;
    ///   - a transport failure falls back to HTTP;
    ///   - a Flight server error is returned as `ToolError::Http` without fallback.
    ///
    /// The reported duration covers the fetch only. Config parsing happens before the
    /// clock starts.
    ///
    /// NOTE(review): Flight `Transport` errors also cover local TLS/mTLS configuration
    /// problems, so those also fall back to HTTP here. Confirm that this downgrade is intended.
    async fn execute(
        &self,
        config: &ToolConfig,
        ctx: &ExecutionContext,
    ) -> Result<ToolResult, ToolError> {
        let cfg = self.parse_config(config, ctx)?;
        let start = std::time::Instant::now();
        tracing::debug!(
            ref_uri = %cfg.r#ref,
            prefer = ?cfg.prefer,
            server_url = %ctx.server_url,
            "Executing result_fetch",
        );
        let data = match cfg.prefer {
            BackendPreference::Http => self.fetch_via_http(&cfg, ctx).await?,
            BackendPreference::Flight => match self.fetch_via_flight(&cfg, ctx).await {
                Ok(v) => v,
                Err(FlightFetchError::NonTabular { ref_uri, message }) => {
                    tracing::debug!(
                        ref_uri = %ref_uri,
                        message = %message,
                        "Flight signalled non-tabular; falling back to HTTP",
                    );
                    self.fetch_via_http(&cfg, ctx).await?
                }
                Err(FlightFetchError::Transport(msg)) => {
                    tracing::warn!(
                        ref_uri = %cfg.r#ref,
                        error = %msg,
                        "Flight transport failed; falling back to HTTP",
                    );
                    self.fetch_via_http(&cfg, ctx).await?
                }
                Err(FlightFetchError::Server(msg)) => {
                    return Err(ToolError::Http(format!(
                        "Flight server error for ref {}: {}",
                        cfg.r#ref, msg
                    )));
                }
            },
        };
        // `as_millis` returns u128. Saturate instead of silently truncating with `as`.
        let duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
        Ok(ToolResult::success(data).with_duration(duration_ms))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the execution context shared by the `build_flight_config` tests.
    fn test_ctx() -> ExecutionContext {
        ExecutionContext::new(1, "step", "http://noetl")
    }

    /// Deserializes `json` into a config, then serializes and re-parses it, asserting that
    /// the serialized form is stable. This makes the `config_round_trips_*` tests exercise
    /// both directions.
    fn round_trip(json: serde_json::Value) -> ResultFetchConfig {
        let cfg: ResultFetchConfig = serde_json::from_value(json).expect("config should parse");
        let first = serde_json::to_value(&cfg).expect("config should serialize");
        let reparsed: ResultFetchConfig =
            serde_json::from_value(first.clone()).expect("serialized config should re-parse");
        let second = serde_json::to_value(&reparsed).expect("re-parsed config should serialize");
        assert_eq!(first, second, "serialize/deserialize is not stable");
        reparsed
    }

    /// Unwraps a `FlightFetchError::Transport` message, panicking on any other variant.
    fn expect_transport(err: FlightFetchError) -> String {
        match err {
            FlightFetchError::Transport(msg) => msg,
            other => panic!("expected Transport error, got {other:?}"),
        }
    }

    /// Per-process scratch directory that is removed on drop, so cleanup also happens
    /// when an assertion panics.
    struct ScratchDir(std::path::PathBuf);

    impl ScratchDir {
        fn new(tag: &str) -> Self {
            let dir = std::env::temp_dir()
                .join(format!("noetl-tools-{tag}-test-{}", std::process::id()));
            std::fs::create_dir_all(&dir).unwrap();
            Self(dir)
        }

        /// Writes `contents` to `name` inside the directory and returns the full path.
        fn write(&self, name: &str, contents: &str) -> std::path::PathBuf {
            let path = self.0.join(name);
            std::fs::write(&path, contents).unwrap();
            path
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    const FAKE_CERT: &str = "-----BEGIN CERTIFICATE-----\nfake\n-----END CERTIFICATE-----";
    const FAKE_KEY: &str = "-----BEGIN PRIVATE KEY-----\nfake\n-----END PRIVATE KEY-----";
    const PARTIAL_PAIR_MSG: &str = "client_cert_path and client_key_path must both be set";

    #[test]
    fn derive_flight_endpoint_swaps_port_8082_to_8083() {
        assert_eq!(
            ResultFetchTool::derive_flight_endpoint("http://noetl.noetl.svc.cluster.local:8082"),
            "http://noetl.noetl.svc.cluster.local:8083",
        );
        assert_eq!(
            ResultFetchTool::derive_flight_endpoint("http://localhost:8082"),
            "http://localhost:8083",
        );
        assert_eq!(
            ResultFetchTool::derive_flight_endpoint("https://noetl.example.com:8082"),
            "https://noetl.example.com:8083",
        );
    }

    #[test]
    fn derive_flight_endpoint_passes_through_non_8082() {
        assert_eq!(
            ResultFetchTool::derive_flight_endpoint("http://noetl.example.com"),
            "http://noetl.example.com",
        );
        assert_eq!(
            ResultFetchTool::derive_flight_endpoint("http://noetl.example.com:9000"),
            "http://noetl.example.com:9000",
        );
    }

    #[test]
    fn derive_flight_endpoint_defaults_to_http_when_scheme_missing() {
        assert_eq!(
            ResultFetchTool::derive_flight_endpoint("noetl.example.com:8082"),
            "http://noetl.example.com:8083",
        );
    }

    #[test]
    fn default_prefer_is_flight() {
        let cfg: ResultFetchConfig = serde_json::from_value(serde_json::json!({
            "ref": "noetl://execution/12345/result/big_select/abcd1234",
        }))
        .unwrap();
        assert_eq!(cfg.prefer, BackendPreference::Flight);
        assert_eq!(
            cfg.r#ref,
            "noetl://execution/12345/result/big_select/abcd1234"
        );
        assert!(cfg.flight_endpoint.is_none());
    }

    #[test]
    fn config_round_trips_http_preference() {
        let cfg = round_trip(serde_json::json!({
            "ref": "noetl://execution/1/result/x/y",
            "prefer": "http",
        }));
        assert_eq!(cfg.prefer, BackendPreference::Http);
    }

    #[test]
    fn config_round_trips_explicit_flight_endpoint() {
        let cfg = round_trip(serde_json::json!({
            "ref": "noetl://execution/1/result/x/y",
            "flight_endpoint": "grpc://other-server.example.com:9999",
        }));
        assert_eq!(
            cfg.flight_endpoint.as_deref(),
            Some("grpc://other-server.example.com:9999"),
        );
    }

    #[test]
    fn tool_name_is_result_fetch() {
        let tool = ResultFetchTool::new();
        assert_eq!(tool.name(), "result_fetch");
    }

    #[test]
    fn default_tool_matches_new() {
        let tool = ResultFetchTool::default();
        assert_eq!(tool.name(), ResultFetchTool::new().name());
    }

    #[test]
    fn config_round_trips_bearer_token() {
        let cfg = round_trip(serde_json::json!({
            "ref": "noetl://execution/1/result/x/y",
            "bearer_token": "noetl_flight_token",
        }));
        assert_eq!(cfg.bearer_token.as_deref(), Some("noetl_flight_token"));
        assert!(cfg.tls_ca_path.is_none());
    }

    #[test]
    fn config_round_trips_tls_ca_path() {
        let cfg = round_trip(serde_json::json!({
            "ref": "noetl://execution/1/result/x/y",
            "tls_ca_path": "/etc/noetl/flight-ca.pem",
        }));
        assert_eq!(cfg.tls_ca_path.as_deref(), Some("/etc/noetl/flight-ca.pem"));
        assert!(cfg.bearer_token.is_none());
    }

    #[test]
    fn config_round_trips_full_auth_shape() {
        let cfg = round_trip(serde_json::json!({
            "ref": "noetl://execution/1/result/x/y",
            "prefer": "flight",
            "flight_endpoint": "https://noetl.example.com:8083",
            "bearer_token": "noetl_flight_token",
            "tls_ca_path": "/etc/noetl/flight-ca.pem",
        }));
        assert_eq!(cfg.prefer, BackendPreference::Flight);
        assert_eq!(cfg.bearer_token.as_deref(), Some("noetl_flight_token"));
        assert_eq!(cfg.tls_ca_path.as_deref(), Some("/etc/noetl/flight-ca.pem"));
    }

    #[test]
    fn config_round_trips_client_identity_paths() {
        let cfg = round_trip(serde_json::json!({
            "ref": "noetl://execution/1/result/x/y",
            "client_cert_path": "/etc/noetl/worker-client.crt",
            "client_key_path": "/etc/noetl/worker-client.key",
        }));
        assert_eq!(
            cfg.client_cert_path.as_deref(),
            Some("/etc/noetl/worker-client.crt"),
        );
        assert_eq!(
            cfg.client_key_path.as_deref(),
            Some("/etc/noetl/worker-client.key"),
        );
    }

    #[test]
    fn config_round_trips_full_mtls_shape() {
        let cfg = round_trip(serde_json::json!({
            "ref": "noetl://execution/1/result/x/y",
            "prefer": "flight",
            "flight_endpoint": "https://noetl.example.com:8083",
            "bearer_token": "noetl_flight_token",
            "tls_ca_path": "/etc/noetl/flight-ca.pem",
            "client_cert_path": "/etc/noetl/worker-client.crt",
            "client_key_path": "/etc/noetl/worker-client.key",
        }));
        assert_eq!(cfg.bearer_token.as_deref(), Some("noetl_flight_token"));
        assert_eq!(cfg.tls_ca_path.as_deref(), Some("/etc/noetl/flight-ca.pem"));
        assert_eq!(
            cfg.client_cert_path.as_deref(),
            Some("/etc/noetl/worker-client.crt"),
        );
        assert_eq!(
            cfg.client_key_path.as_deref(),
            Some("/etc/noetl/worker-client.key"),
        );
    }

    #[test]
    fn build_flight_config_rejects_partial_client_identity_cert_without_key() {
        let cfg: ResultFetchConfig = serde_json::from_value(serde_json::json!({
            "ref": "noetl://execution/1/result/x/y",
            "client_cert_path": "/etc/noetl/worker-client.crt",
        }))
        .unwrap();
        let err = ResultFetchTool::build_flight_config(&cfg, &test_ctx())
            .expect_err("expected partial-pair error");
        let msg = expect_transport(err);
        assert!(msg.contains(PARTIAL_PAIR_MSG), "got: {msg}");
    }

    #[test]
    fn build_flight_config_rejects_partial_client_identity_key_without_cert() {
        let cfg: ResultFetchConfig = serde_json::from_value(serde_json::json!({
            "ref": "noetl://execution/1/result/x/y",
            "client_key_path": "/etc/noetl/worker-client.key",
        }))
        .unwrap();
        let err = ResultFetchTool::build_flight_config(&cfg, &test_ctx())
            .expect_err("expected partial-pair error");
        let msg = expect_transport(err);
        assert!(msg.contains(PARTIAL_PAIR_MSG), "got: {msg}");
    }

    #[test]
    fn build_flight_config_some_for_client_identity_pair() {
        let dir = ScratchDir::new("mtls");
        let cert_path = dir.write("client.crt", FAKE_CERT);
        let key_path = dir.write("client.key", FAKE_KEY);
        let cfg: ResultFetchConfig = serde_json::from_value(serde_json::json!({
            "ref": "noetl://execution/1/result/x/y",
            "client_cert_path": cert_path.to_str().unwrap(),
            "client_key_path": key_path.to_str().unwrap(),
        }))
        .unwrap();
        let result = ResultFetchTool::build_flight_config(&cfg, &test_ctx()).unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn build_flight_config_error_when_client_cert_unreadable() {
        let cfg: ResultFetchConfig = serde_json::from_value(serde_json::json!({
            "ref": "noetl://execution/1/result/x/y",
            "client_cert_path": "/does/not/exist/client.crt",
            "client_key_path": "/does/not/exist/client.key",
        }))
        .unwrap();
        let err = ResultFetchTool::build_flight_config(&cfg, &test_ctx()).expect_err("unreadable");
        let msg = expect_transport(err);
        assert!(msg.contains("/does/not/exist"), "got: {msg}");
        assert!(msg.contains("client"), "got: {msg}");
    }

    #[test]
    fn resolve_bearer_finds_keychain_alias() {
        let mut ctx = test_ctx();
        ctx.set_secret("noetl_flight_token", "sk-ant-real-token-bytes");
        let resolved = ResultFetchTool::resolve_bearer("noetl_flight_token", &ctx);
        assert_eq!(resolved, "sk-ant-real-token-bytes");
    }

    #[test]
    fn resolve_bearer_falls_through_for_unknown_alias() {
        let resolved = ResultFetchTool::resolve_bearer("sk-literal-not-an-alias", &test_ctx());
        assert_eq!(resolved, "sk-literal-not-an-alias");
    }

    #[test]
    fn build_flight_config_none_when_no_auth_or_tls() {
        let cfg: ResultFetchConfig = serde_json::from_value(serde_json::json!({
            "ref": "noetl://execution/1/result/x/y",
        }))
        .unwrap();
        let result = ResultFetchTool::build_flight_config(&cfg, &test_ctx()).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn build_flight_config_some_for_bearer_only() {
        let cfg: ResultFetchConfig = serde_json::from_value(serde_json::json!({
            "ref": "noetl://execution/1/result/x/y",
            "bearer_token": "tok-literal",
        }))
        .unwrap();
        let result = ResultFetchTool::build_flight_config(&cfg, &test_ctx()).unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn build_flight_config_some_for_tls_only() {
        let dir = ScratchDir::new("tls");
        let ca_path = dir.write("ca.pem", FAKE_CERT);
        let cfg: ResultFetchConfig = serde_json::from_value(serde_json::json!({
            "ref": "noetl://execution/1/result/x/y",
            "tls_ca_path": ca_path.to_str().unwrap(),
        }))
        .unwrap();
        let result = ResultFetchTool::build_flight_config(&cfg, &test_ctx()).unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn build_flight_config_error_when_ca_path_unreadable() {
        let cfg: ResultFetchConfig = serde_json::from_value(serde_json::json!({
            "ref": "noetl://execution/1/result/x/y",
            "tls_ca_path": "/does/not/exist.pem",
        }))
        .unwrap();
        let err = ResultFetchTool::build_flight_config(&cfg, &test_ctx()).expect_err("unreadable");
        let msg = expect_transport(err);
        assert!(msg.contains("/does/not/exist.pem"), "got: {msg}");
        assert!(msg.contains("TLS CA bundle"), "got: {msg}");
    }
}