use sen_plugin_api::Capabilities;
use std::io::{self, BufRead, Write};
use thiserror::Error;
use super::store::StoredTrustLevel;
/// Errors that can occur while asking for approval of a plugin's capabilities.
#[derive(Debug, Error)]
pub enum PromptError {
    /// The user cancelled the prompt.
    ///
    /// None of the handlers in this file currently return this variant.
    #[error("Prompt cancelled by user")]
    Cancelled,

    /// There is no terminal to ask the user on; returned by
    /// `TerminalPromptHandler` when its terminal check fails.
    #[error("Non-interactive environment")]
    NonInteractive,

    /// Writing the prompt or reading the answer failed.
    ///
    /// Created automatically from `io::Error` via `?`.
    #[error("I/O error: {0}")]
    IoError(#[from] io::Error),

    /// No answer arrived in time.
    ///
    /// None of the handlers in this file currently return this variant.
    #[error("Timeout waiting for user response")]
    Timeout,
}
/// The answer given to a capability prompt.
///
/// `Deny` is the `Default` variant.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum PromptResult {
    /// Grant the capabilities for this request only; nothing is persisted.
    AllowOnce,
    /// Grant the capabilities and persist them with the session trust level.
    AllowSession,
    /// Grant the capabilities and persist them with the permanent trust level.
    AllowAlways,
    /// Refuse the requested capabilities.
    #[default]
    Deny,
}
impl PromptResult {
    /// Maps this answer to the trust level that should be stored for it.
    ///
    /// Returns `None` for answers that must not be persisted
    /// (`AllowOnce` and `Deny`).
    pub fn to_trust_level(&self) -> Option<StoredTrustLevel> {
        match self {
            Self::AllowSession => Some(StoredTrustLevel::Session),
            Self::AllowAlways => Some(StoredTrustLevel::Permanent),
            Self::AllowOnce | Self::Deny => None,
        }
    }

    /// Returns `true` for every answer that grants access, i.e. anything but `Deny`.
    pub fn is_allowed(&self) -> bool {
        // Exhaustive on purpose: a new variant must decide explicitly.
        match self {
            Self::AllowOnce | Self::AllowSession | Self::AllowAlways => true,
            Self::Deny => false,
        }
    }

    /// Returns `true` when the answer should be stored beyond this request
    /// (`AllowSession` or `AllowAlways`).
    ///
    /// Defined in terms of [`PromptResult::to_trust_level`] so the two can
    /// never disagree.
    pub fn should_persist(&self) -> bool {
        self.to_trust_level().is_some()
    }
}
/// Strategy for deciding whether a plugin may use a set of capabilities.
///
/// The `Send + Sync` bound allows a handler to be shared across threads.
pub trait PromptHandler: Send + Sync {
    /// Asks whether `plugin` may use `capabilities`.
    fn prompt(
        &self,
        plugin: &str,
        capabilities: &Capabilities,
    ) -> Result<PromptResult, PromptError>;

    /// Reports whether this handler consults a human.
    fn is_interactive(&self) -> bool;

    /// Asks whether `plugin` may move from the previously granted `old_caps`
    /// to the newly requested `new_caps`.
    ///
    /// The default implementation does not look at `old_caps`; it simply
    /// delegates to [`PromptHandler::prompt`] with `new_caps`.
    fn prompt_escalation(
        &self,
        plugin: &str,
        old_caps: &Capabilities,
        new_caps: &Capabilities,
    ) -> Result<PromptResult, PromptError> {
        // Deliberately unused by the default; overriding handlers may show it.
        let _ = old_caps;
        self.prompt(plugin, new_caps)
    }
}
/// Handler that asks the user on the terminal: the question is written to
/// stdout and the answer is read from stdin.
#[derive(Debug)]
pub struct TerminalPromptHandler {
    /// When `true`, the individual capabilities are listed before the question.
    verbose: bool,
}
impl TerminalPromptHandler {
    /// Creates a handler that lists every requested capability before asking.
    pub fn new() -> Self {
        Self { verbose: true }
    }

    /// Creates a handler that only asks the question, without listing the
    /// individual capabilities.
    pub fn minimal() -> Self {
        Self { verbose: false }
    }

    /// Renders `caps` as a bullet list, one capability per line, joined by `\n`.
    ///
    /// Order: file reads, file writes, environment variables (on one line),
    /// network hosts, then stdin/stdout/stderr access. When `caps` requests
    /// nothing, the single line ` - (none)` is returned. This way the prompt
    /// never shows an empty list, e.g. under "Previously granted:".
    fn format_capabilities(&self, caps: &Capabilities) -> String {
        let recursive = |flag: bool| if flag { " (recursive)" } else { "" };
        let mut lines = Vec::new();

        for path in &caps.fs_read {
            lines.push(format!(
                " - Read files in: {}{}",
                path.pattern,
                recursive(path.recursive)
            ));
        }
        for path in &caps.fs_write {
            lines.push(format!(
                " - Write files in: {}{}",
                path.pattern,
                recursive(path.recursive)
            ));
        }
        if !caps.env_read.is_empty() {
            lines.push(format!(" - Read environment: {}", caps.env_read.join(", ")));
        }
        for net in &caps.net {
            let port_str = net.port.map(|p| format!(":{}", p)).unwrap_or_default();
            lines.push(format!(" - Network access: {}{}", net.host, port_str));
        }
        if caps.stdio.stdin {
            lines.push(" - Read from stdin".to_string());
        }
        if caps.stdio.stdout {
            lines.push(" - Write to stdout".to_string());
        }
        if caps.stdio.stderr {
            lines.push(" - Write to stderr".to_string());
        }

        if lines.is_empty() {
            // Otherwise an empty request renders as a blank line, which looks
            // like missing output rather than "nothing requested".
            lines.push(" - (none)".to_string());
        }
        lines.join("\n")
    }
}
impl Default for TerminalPromptHandler {
fn default() -> Self {
Self::new()
}
}
impl PromptHandler for TerminalPromptHandler {
    /// Asks on the terminal whether `plugin` may use `capabilities`.
    ///
    /// Accepted answers (case-insensitive, surrounding whitespace ignored):
    /// `y`/`yes` → `AllowOnce`, `s`/`session` → `AllowSession`,
    /// `a`/`always` → `AllowAlways`, `n`/`no` or an empty line → `Deny`.
    /// Any other input is reported and treated as `Deny`.
    ///
    /// # Errors
    ///
    /// [`PromptError::NonInteractive`] when the terminal check fails, and
    /// [`PromptError::IoError`] when writing the prompt or reading the answer fails.
    fn prompt(
        &self,
        plugin: &str,
        capabilities: &Capabilities,
    ) -> Result<PromptResult, PromptError> {
        // Bail out before printing anything: nobody is there to answer.
        if !atty_check() {
            return Err(PromptError::NonInteractive);
        }
        let mut stdout = io::stdout();

        writeln!(stdout)?;
        writeln!(
            stdout,
            "Plugin \"{}\" requests the following permissions:",
            plugin
        )?;
        writeln!(stdout)?;
        if self.verbose {
            writeln!(stdout, "{}", self.format_capabilities(capabilities))?;
            writeln!(stdout)?;
        }
        write!(stdout, "Allow? [y]es / [n]o / [a]lways / [s]ession: ")?;

        let answer = read_answer(&mut stdout)?;
        resolve_answer(&mut stdout, &answer, true)
    }

    /// Returns the result of the terminal check used by the prompts.
    fn is_interactive(&self) -> bool {
        atty_check()
    }

    /// Warns that `plugin` wants more than `old_caps` and asks on the terminal
    /// whether to grant `new_caps`.
    ///
    /// Same answers as [`TerminalPromptHandler::prompt`], except that
    /// `s`/`session` is not offered here. It is treated like any other
    /// unrecognized input: reported and denied.
    ///
    /// # Errors
    ///
    /// Same as [`TerminalPromptHandler::prompt`].
    fn prompt_escalation(
        &self,
        plugin: &str,
        old_caps: &Capabilities,
        new_caps: &Capabilities,
    ) -> Result<PromptResult, PromptError> {
        if !atty_check() {
            return Err(PromptError::NonInteractive);
        }
        let mut stdout = io::stdout();

        writeln!(stdout)?;
        writeln!(
            stdout,
            "WARNING: Plugin \"{}\" requests ADDITIONAL permissions!",
            plugin
        )?;
        writeln!(stdout)?;
        if self.verbose {
            writeln!(stdout, "Previously granted:")?;
            writeln!(stdout, "{}", self.format_capabilities(old_caps))?;
            writeln!(stdout)?;
            writeln!(stdout, "Now requesting:")?;
            writeln!(stdout, "{}", self.format_capabilities(new_caps))?;
            writeln!(stdout)?;
        }
        write!(stdout, "Allow escalation? [y]es / [n]o / [a]lways: ")?;

        let answer = read_answer(&mut stdout)?;
        // Report bad input just like `prompt` does, instead of denying silently.
        resolve_answer(&mut stdout, &answer, false)
    }
}

/// Flushes the pending prompt text to `out`, then reads one line from stdin.
///
/// Returns the line trimmed and lower-cased. At end of input, `read_line`
/// leaves the buffer empty, so this yields `""`.
fn read_answer(out: &mut impl Write) -> io::Result<String> {
    out.flush()?;
    let mut line = String::new();
    io::stdin().lock().read_line(&mut line)?;
    Ok(line.trim().to_lowercase())
}

/// Maps a normalized terminal answer to a [`PromptResult`].
///
/// `s`/`session` is accepted only when `allow_session` is set. An empty answer
/// is a deny. Any unrecognized answer is reported on `out` and also denied, so
/// a typo can never grant access.
fn resolve_answer(
    out: &mut impl Write,
    answer: &str,
    allow_session: bool,
) -> Result<PromptResult, PromptError> {
    let result = match answer {
        "y" | "yes" => PromptResult::AllowOnce,
        "a" | "always" => PromptResult::AllowAlways,
        "s" | "session" if allow_session => PromptResult::AllowSession,
        "n" | "no" | "" => PromptResult::Deny,
        _ => {
            writeln!(out, "Invalid input, defaulting to deny")?;
            PromptResult::Deny
        }
    };
    Ok(result)
}
/// Non-interactive handler that answers every prompt with one fixed response.
#[derive(Debug)]
pub struct AutoPromptHandler {
    /// The response returned (cloned) for every prompt.
    default_response: PromptResult,
}
impl AutoPromptHandler {
    /// Creates a handler that grants every request with `PromptResult::AllowAlways`.
    pub fn always_allow() -> Self {
        Self::with_response(PromptResult::AllowAlways)
    }

    /// Creates a handler that refuses every request with `PromptResult::Deny`.
    pub fn always_deny() -> Self {
        Self::with_response(PromptResult::Deny)
    }

    /// Creates a handler that answers every request with `response`.
    pub fn with_response(response: PromptResult) -> Self {
        Self {
            default_response: response,
        }
    }
}
/// Escalations use the trait's default `prompt_escalation`, which calls
/// `prompt`, so they receive the same fixed response.
impl PromptHandler for AutoPromptHandler {
    /// Returns the configured response without inspecting the request.
    fn prompt(
        &self,
        _plugin: &str,
        _capabilities: &Capabilities,
    ) -> Result<PromptResult, PromptError> {
        let answer = self.default_response.clone();
        Ok(answer)
    }

    /// Always `false`: no human is consulted.
    fn is_interactive(&self) -> bool {
        false
    }
}
/// Handler that logs every prompt it receives and answers with a fixed response.
#[derive(Debug, Default)]
pub struct RecordingPromptHandler {
    /// Prompts received so far, in arrival order. Kept behind a mutex because
    /// the `PromptHandler` methods only get `&self`.
    prompts: std::sync::Mutex<Vec<RecordedPrompt>>,
    /// Answer returned for every prompt (`Deny` when built via `Default`).
    response: PromptResult,
}
/// One prompt captured by [`RecordingPromptHandler`].
#[derive(Debug, Clone)]
pub struct RecordedPrompt {
    /// Name of the plugin the prompt was about.
    pub plugin: String,
    /// `compute_hash()` of the requested capabilities; for escalations this is
    /// the hash of the new capability set.
    pub capabilities_hash: String,
    /// `true` when recorded through `prompt_escalation`, `false` for `prompt`.
    pub is_escalation: bool,
}
impl RecordingPromptHandler {
    /// Creates a recorder with an empty log that answers every prompt with `response`.
    pub fn new(response: PromptResult) -> Self {
        Self {
            response,
            prompts: std::sync::Mutex::new(Vec::new()),
        }
    }

    /// Returns a snapshot of every recorded prompt, oldest first.
    ///
    /// # Panics
    ///
    /// Panics if the log mutex is poisoned.
    pub fn prompts(&self) -> Vec<RecordedPrompt> {
        self.recorded().clone()
    }

    /// Returns how many prompts have been recorded.
    ///
    /// # Panics
    ///
    /// Panics if the log mutex is poisoned.
    pub fn prompt_count(&self) -> usize {
        self.recorded().len()
    }

    /// Forgets all recorded prompts.
    ///
    /// # Panics
    ///
    /// Panics if the log mutex is poisoned.
    pub fn clear(&self) {
        self.recorded().clear();
    }

    /// Locks the prompt log, panicking with a fixed message if it is poisoned.
    fn recorded(&self) -> std::sync::MutexGuard<'_, Vec<RecordedPrompt>> {
        self.prompts
            .lock()
            .expect("RecordingPromptHandler mutex poisoned")
    }
}
impl PromptHandler for RecordingPromptHandler {
    /// Records a regular prompt (`is_escalation: false`) and returns the
    /// configured response.
    ///
    /// # Panics
    ///
    /// Panics if the log mutex is poisoned.
    fn prompt(
        &self,
        plugin: &str,
        capabilities: &Capabilities,
    ) -> Result<PromptResult, PromptError> {
        // Build the record before locking. Hashing then runs outside the
        // critical section, and a panic in `compute_hash` cannot poison the log.
        let record = RecordedPrompt {
            plugin: plugin.to_string(),
            capabilities_hash: capabilities.compute_hash(),
            is_escalation: false,
        };
        self.prompts
            .lock()
            .expect("RecordingPromptHandler mutex poisoned")
            .push(record);
        Ok(self.response.clone())
    }

    /// Always `false`: no human is consulted.
    fn is_interactive(&self) -> bool {
        false
    }

    /// Records an escalation (`is_escalation: true`, hash of `new_caps`) and
    /// returns the configured response. `_old_caps` is not recorded.
    ///
    /// # Panics
    ///
    /// Panics if the log mutex is poisoned.
    fn prompt_escalation(
        &self,
        plugin: &str,
        _old_caps: &Capabilities,
        new_caps: &Capabilities,
    ) -> Result<PromptResult, PromptError> {
        let record = RecordedPrompt {
            plugin: plugin.to_string(),
            capabilities_hash: new_caps.compute_hash(),
            is_escalation: true,
        };
        self.prompts
            .lock()
            .expect("RecordingPromptHandler mutex poisoned")
            .push(record);
        Ok(self.response.clone())
    }
}
/// Returns `true` when a user can answer a terminal prompt.
///
/// Both ends must be terminals: the prompt is written to stdout and the
/// answer is read from stdin. If only stdout were checked, a redirected stdin
/// (pipe or file) would have its first line consumed as the "answer". That
/// line might be data meant for the plugin, and input starting with `y`, `a`
/// or `s` would silently grant permissions.
///
/// On platforms that are neither Unix nor Windows, this falls back to
/// checking whether `TERM` is set.
fn atty_check() -> bool {
    #[cfg(unix)]
    {
        use std::os::unix::io::AsRawFd;
        let is_tty = |fd: std::os::unix::io::RawFd| {
            // SAFETY: `isatty` only queries the descriptor and has no
            // memory-safety preconditions; the descriptors come from this
            // process's own stdin/stdout handles.
            unsafe { libc::isatty(fd) != 0 }
        };
        is_tty(std::io::stdin().as_raw_fd()) && is_tty(std::io::stdout().as_raw_fd())
    }
    #[cfg(windows)]
    {
        use std::os::windows::io::{AsRawHandle, RawHandle};
        use windows_sys::Win32::System::Console::{GetConsoleMode, CONSOLE_MODE};
        let is_console = |handle: RawHandle| {
            let mut mode: CONSOLE_MODE = 0;
            // SAFETY: `GetConsoleMode` writes only to `mode`, a valid local,
            // and returns 0 for handles that are not consoles.
            unsafe { GetConsoleMode(handle as _, &mut mode) != 0 }
        };
        is_console(std::io::stdin().as_raw_handle())
            && is_console(std::io::stdout().as_raw_handle())
    }
    #[cfg(not(any(unix, windows)))]
    {
        std::env::var("TERM").is_ok()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use sen_plugin_api::PathPattern;

    #[test]
    fn test_prompt_result() {
        assert!(PromptResult::AllowOnce.is_allowed());
        assert!(PromptResult::AllowSession.is_allowed());
        assert!(PromptResult::AllowAlways.is_allowed());
        assert!(!PromptResult::Deny.is_allowed());
        assert!(!PromptResult::AllowOnce.should_persist());
        assert!(PromptResult::AllowAlways.should_persist());
        assert!(PromptResult::AllowSession.should_persist());
        assert!(!PromptResult::Deny.should_persist());
        assert_eq!(PromptResult::default(), PromptResult::Deny);
    }

    #[test]
    fn test_to_trust_level() {
        assert!(PromptResult::AllowOnce.to_trust_level().is_none());
        assert!(PromptResult::Deny.to_trust_level().is_none());
        assert!(matches!(
            PromptResult::AllowSession.to_trust_level(),
            Some(StoredTrustLevel::Session)
        ));
        assert!(matches!(
            PromptResult::AllowAlways.to_trust_level(),
            Some(StoredTrustLevel::Permanent)
        ));
    }

    #[test]
    fn test_auto_handler() {
        let handler = AutoPromptHandler::always_allow();
        let caps = Capabilities::none();
        let result = handler.prompt("test", &caps).unwrap();
        assert_eq!(result, PromptResult::AllowAlways);
        // The default `prompt_escalation` delegates to `prompt`.
        let escalated = handler.prompt_escalation("test", &caps, &caps).unwrap();
        assert_eq!(escalated, PromptResult::AllowAlways);
        assert!(!handler.is_interactive());

        let handler = AutoPromptHandler::always_deny();
        let result = handler.prompt("test", &caps).unwrap();
        assert_eq!(result, PromptResult::Deny);

        let handler = AutoPromptHandler::with_response(PromptResult::AllowSession);
        let result = handler.prompt("test", &caps).unwrap();
        assert_eq!(result, PromptResult::AllowSession);
    }

    #[test]
    fn test_recording_handler() {
        let handler = RecordingPromptHandler::new(PromptResult::AllowOnce);
        let caps = Capabilities::default().with_fs_read(vec![PathPattern::new("./data")]);
        handler.prompt("plugin1", &caps).unwrap();
        handler.prompt("plugin2", &caps).unwrap();
        assert_eq!(handler.prompt_count(), 2);
        let prompts = handler.prompts();
        assert_eq!(prompts[0].plugin, "plugin1");
        assert_eq!(prompts[1].plugin, "plugin2");
        assert!(!prompts[0].is_escalation);
        assert_eq!(prompts[0].capabilities_hash, caps.compute_hash());

        handler.clear();
        assert_eq!(handler.prompt_count(), 0);
        assert!(handler.prompts().is_empty());
    }

    #[test]
    fn test_recording_handler_escalation() {
        let handler = RecordingPromptHandler::new(PromptResult::AllowSession);
        let old_caps = Capabilities::none();
        let new_caps = Capabilities::default().with_fs_write(vec![PathPattern::new("./out")]);
        let result = handler
            .prompt_escalation("plugin", &old_caps, &new_caps)
            .unwrap();
        assert_eq!(result, PromptResult::AllowSession);

        let prompts = handler.prompts();
        assert_eq!(prompts.len(), 1);
        assert_eq!(prompts[0].plugin, "plugin");
        assert!(prompts[0].is_escalation);
        assert_eq!(prompts[0].capabilities_hash, new_caps.compute_hash());
    }

    #[test]
    fn test_recording_handler_default_denies() {
        let handler = RecordingPromptHandler::default();
        let caps = Capabilities::none();
        assert_eq!(handler.prompt("p", &caps).unwrap(), PromptResult::Deny);
        assert!(!handler.is_interactive());
    }

    #[test]
    fn test_format_capabilities() {
        let handler = TerminalPromptHandler::new();
        let caps = Capabilities::default()
            .with_fs_read(vec![PathPattern::new("./data").recursive()])
            .with_fs_write(vec![PathPattern::new("./output")])
            .with_env_read(vec!["HOME".into(), "PATH".into()]);
        let formatted = handler.format_capabilities(&caps);
        assert!(formatted.contains("./data"));
        assert!(formatted.contains("recursive"));
        assert!(formatted.contains("./output"));
        assert!(formatted.contains("HOME"));
        assert!(formatted.contains("PATH"));
    }
}