use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Version number of the client/daemon protocol described by the types in this module.
///
/// NOTE(review): nothing in this file compares or negotiates this value; confirm how
/// peers use it and bump it whenever a wire change is not backward compatible.
pub const PROTOCOL_VERSION: u32 = 1;
/// A shell dialect that completions can be requested for.
///
/// Serde representation: each variant is its lowercase name (`"zsh"`, `"bash"`, `"fish"`,
/// `"powershell"`, `"nushell"`). Deserialization additionally accepts the aliases `"pwsh"`
/// and `"nu"`. Serde matching is exact; the `FromStr` impl is the lenient, human-facing
/// parser (case-insensitive, accepts `sh`, versioned names, ...).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Shell {
    /// Z shell; wire name `"zsh"`.
    Zsh,
    /// Bash; wire name `"bash"`.
    Bash,
    /// fish; wire name `"fish"`.
    Fish,
    /// PowerShell; wire name `"powershell"`, also deserialized from `"pwsh"`.
    #[serde(alias = "pwsh")]
    PowerShell,
    /// Nushell; wire name `"nushell"`, also deserialized from `"nu"`.
    #[serde(alias = "nu")]
    Nushell,
}
impl std::str::FromStr for Shell {
    type Err = String;

    /// Parses a shell name leniently.
    ///
    /// Recognised names (after normalisation): `zsh`; `bash` or `sh` (both map to
    /// [`Shell::Bash`]); `fish`; `powershell` or `pwsh`; `nushell` or `nu`.
    ///
    /// Normalisation, applied in order:
    /// 1. surrounding whitespace is trimmed and the text is lowercased;
    /// 2. leading `-` characters are removed, so login-shell style names such as `-zsh`
    ///    are accepted;
    /// 3. a trailing `.exe` is removed, so Windows executable names such as `pwsh.exe`
    ///    are accepted;
    /// 4. everything from the first remaining `-` onward is dropped, so versioned names
    ///    such as `bash-5.2` are accepted.
    ///
    /// # Errors
    ///
    /// Returns `Err("unknown shell: <input>")`, quoting the original input, for any other
    /// name, including the empty string.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lower = s.trim().to_lowercase();
        let name = lower.trim_start_matches('-');
        let name = name.strip_suffix(".exe").unwrap_or(name);
        // Drop a version suffix: "bash-5.2" -> "bash". No '-' means the whole name.
        let base = name.split_once('-').map_or(name, |(head, _)| head);
        match base {
            "zsh" => Ok(Shell::Zsh),
            "bash" | "sh" => Ok(Shell::Bash),
            "fish" => Ok(Shell::Fish),
            "powershell" | "pwsh" => Ok(Shell::PowerShell),
            "nushell" | "nu" => Ok(Shell::Nushell),
            _ => Err(format!("unknown shell: {s}")),
        }
    }
}
impl Shell {
    /// Returns the canonical lowercase name of this shell.
    ///
    /// The returned name is accepted by [`str::parse`] and maps back to the same variant.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Shell::Nushell => "nushell",
            Shell::PowerShell => "powershell",
            Shell::Fish => "fish",
            Shell::Bash => "bash",
            Shell::Zsh => "zsh",
        }
    }

    /// Detects the user's shell from the current process environment.
    ///
    /// Reads `NIGHTHAWK_SHELL` and then `SHELL`. An unset or non-UTF-8 variable counts as
    /// absent. Both values are forwarded to [`Shell::detect_from`].
    pub fn detect_default() -> Self {
        let override_var = std::env::var("NIGHTHAWK_SHELL").ok();
        let login_shell = std::env::var("SHELL").ok();
        Self::detect_from(override_var, login_shell)
    }

    /// Chooses a shell from an explicit override and a `$SHELL`-style value.
    ///
    /// Precedence:
    /// 1. `nighthawk_shell`, when present and parseable as a [`Shell`];
    /// 2. on Windows, always [`Shell::PowerShell`] (`shell_env` is ignored);
    /// 3. elsewhere, the last `/`-separated component of `shell_env`, when parseable
    ///    (e.g. `/usr/local/bin/fish` gives [`Shell::Fish`]);
    /// 4. otherwise [`Shell::Zsh`].
    ///
    /// Taking both inputs as parameters lets callers and tests avoid the real environment.
    pub fn detect_from(nighthawk_shell: Option<String>, shell_env: Option<String>) -> Self {
        let explicit = nighthawk_shell
            .as_deref()
            .and_then(|raw| raw.parse::<Shell>().ok());
        if let Some(shell) = explicit {
            return shell;
        }

        #[cfg(windows)]
        {
            // `$SHELL` is not consulted on Windows; reference it so it is not flagged unused.
            let _ = shell_env;
            Shell::PowerShell
        }
        #[cfg(not(windows))]
        {
            let program = shell_env.as_deref().and_then(|path| path.rsplit('/').next());
            match program.map(str::parse::<Shell>) {
                Some(Ok(shell)) => shell,
                _ => Shell::Zsh,
            }
        }
    }
}
/// A request for completions of a partially typed command line.
///
/// Serialized with serde using the field names below as JSON keys, in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// The command line text being completed (e.g. `"git ch"`).
    pub input: String,
    /// Cursor position within `input`.
    /// NOTE(review): unit (byte vs. char offset) is not established here; confirm with producers.
    pub cursor: usize,
    /// Working directory of the requesting shell; may be a Windows drive or UNC path.
    pub cwd: PathBuf,
    /// Shell dialect `input` is written in.
    pub shell: Shell,
}
/// The reply to a [`CompletionRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Candidate completions; may be empty.
    /// NOTE(review): ordering is not defined here; confirm whether producers sort by `confidence`.
    pub suggestions: Vec<Suggestion>,
}
/// A single completion candidate.
///
/// Serialized with serde in field-declaration order. `description: None` is emitted as
/// `null`, whereas `diff_ops: None` is omitted entirely.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Suggestion {
    /// Replacement text to insert.
    pub text: String,
    /// Start of the range of the request's `input` that `text` replaces.
    /// NOTE(review): presumably the same offset unit as `CompletionRequest::cursor`; confirm.
    pub replace_start: usize,
    /// End of the replaced range; presumably exclusive (e.g. `4..6` replaces `"ch"` in
    /// `"git ch"`). Confirm with producers.
    pub replace_end: usize,
    /// Producer-assigned score; presumably in `0.0..=1.0`. Confirm with producers.
    pub confidence: f32,
    /// Which subsystem produced this suggestion.
    pub source: SuggestionSource,
    /// Optional human-readable explanation (e.g. `"Switch branches or restore files"`).
    pub description: Option<String>,
    /// Optional character-level edit script for this suggestion.
    /// Omitted from the JSON when `None`, and defaults to `None` when the key is absent,
    /// so payloads without this field still deserialize.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub diff_ops: Option<Vec<DiffOp>>,
}
/// One step of a character-level edit script.
///
/// NOTE(review): presumably this transforms the replaced input range into
/// `Suggestion::text`; confirm with producers.
///
/// Serde uses the adjacently tagged form: the variant name in snake_case under `"op"` and
/// the character under `"ch"`, e.g. `DiffOp::Keep('c')` serializes as
/// `{"op":"keep","ch":"c"}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "op", content = "ch")]
pub enum DiffOp {
    /// Character kept unchanged (`"op":"keep"`).
    Keep(char),
    /// Character removed (`"op":"delete"`).
    Delete(char),
    /// Character added (`"op":"insert"`).
    Insert(char),
}
/// Which subsystem produced a [`Suggestion`].
///
/// Serialized as a snake_case string: `"history"`, `"spec"`, `"local_model"`,
/// `"cloud_model"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SuggestionSource {
    /// Presumably the user's command history.
    History,
    /// Presumably a command specification.
    Spec,
    /// Presumably a model running locally.
    LocalModel,
    /// Presumably a remotely hosted model.
    CloudModel,
}
/// Returns the default IPC endpoint path for this platform.
///
/// * Unix: a per-user path, `/tmp/nighthawk-<uid>.sock`, keyed by the real user id.
/// * Windows: the named pipe `\\.\pipe\nighthawk`.
///
/// NOTE(review): the Unix path lives in world-writable `/tmp` and is predictable; consider
/// `$XDG_RUNTIME_DIR` if endpoint squatting is a concern.
pub fn default_socket_path() -> PathBuf {
    #[cfg(windows)]
    {
        PathBuf::from(r"\\.\pipe\nighthawk")
    }
    #[cfg(unix)]
    {
        // SAFETY: `getuid` takes no arguments, reads no caller-provided memory and is
        // specified by POSIX to always succeed, so calling it cannot cause UB.
        let uid = unsafe { libc::getuid() };
        format!("/tmp/nighthawk-{uid}.sock").into()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `Shell` variant, for tests that must cover the whole enum.
    const ALL_SHELLS: [Shell; 5] = [
        Shell::Zsh,
        Shell::Bash,
        Shell::Fish,
        Shell::PowerShell,
        Shell::Nushell,
    ];

    #[test]
    fn request_roundtrip_json() {
        let req = CompletionRequest {
            input: "git ch".into(),
            cursor: 6,
            cwd: PathBuf::from("/home/user/project"),
            shell: Shell::Zsh,
        };
        let json = serde_json::to_string(&req).unwrap();
        let parsed: CompletionRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.input, "git ch");
        assert_eq!(parsed.cursor, 6);
        assert_eq!(parsed.cwd, PathBuf::from("/home/user/project"));
        assert_eq!(parsed.shell, Shell::Zsh);
    }

    #[test]
    fn response_roundtrip_json() {
        let resp = CompletionResponse {
            suggestions: vec![Suggestion {
                text: "checkout".into(),
                replace_start: 4,
                replace_end: 6,
                confidence: 0.95,
                source: SuggestionSource::Spec,
                description: Some("Switch branches or restore files".into()),
                diff_ops: None,
            }],
        };
        let json = serde_json::to_string(&resp).unwrap();
        let parsed: CompletionResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.suggestions.len(), 1);
        let s = &parsed.suggestions[0];
        assert_eq!(s.text, "checkout");
        assert_eq!(s.replace_start, 4);
        assert_eq!(s.replace_end, 6);
        assert_eq!(s.source, SuggestionSource::Spec);
        assert_eq!(
            s.description.as_deref(),
            Some("Switch branches or restore files")
        );
        assert!(s.diff_ops.is_none());
    }

    #[test]
    fn shell_serde() {
        let s = Shell::PowerShell;
        let json = serde_json::to_string(&s).unwrap();
        assert_eq!(json, "\"powershell\"");
        let parsed: Shell = serde_json::from_str("\"pwsh\"").unwrap();
        assert_eq!(parsed, Shell::PowerShell);
    }

    #[test]
    fn shell_serde_matches_as_str() {
        // The wire name and `as_str` must agree so either can be fed back to the other side.
        for shell in ALL_SHELLS {
            assert_eq!(
                serde_json::to_string(&shell).unwrap(),
                format!("\"{}\"", shell.as_str()),
                "serde name differs from as_str for {:?}",
                shell,
            );
        }
    }

    #[test]
    fn shell_serde_nu_alias() {
        let parsed: Shell = serde_json::from_str("\"nu\"").unwrap();
        assert_eq!(parsed, Shell::Nushell);
    }

    #[test]
    fn diff_op_roundtrip() {
        let ops = vec![
            DiffOp::Keep('c'),
            DiffOp::Delete('a'),
            DiffOp::Insert('e'),
            DiffOp::Keep('k'),
        ];
        let json = serde_json::to_string(&ops).unwrap();
        let parsed: Vec<DiffOp> = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, ops);
    }

    #[test]
    fn diff_op_wire_format() {
        // Adjacently tagged: {"op": <snake_case variant>, "ch": <char>}.
        assert_eq!(
            serde_json::to_string(&DiffOp::Keep('c')).unwrap(),
            r#"{"op":"keep","ch":"c"}"#
        );
        assert_eq!(
            serde_json::to_string(&DiffOp::Delete('a')).unwrap(),
            r#"{"op":"delete","ch":"a"}"#
        );
        let parsed: DiffOp = serde_json::from_str(r#"{"op":"insert","ch":"e"}"#).unwrap();
        assert_eq!(parsed, DiffOp::Insert('e'));
    }

    #[test]
    fn suggestion_source_wire_format() {
        let cases = [
            (SuggestionSource::History, "\"history\""),
            (SuggestionSource::Spec, "\"spec\""),
            (SuggestionSource::LocalModel, "\"local_model\""),
            (SuggestionSource::CloudModel, "\"cloud_model\""),
        ];
        for (source, json) in cases {
            assert_eq!(serde_json::to_string(&source).unwrap(), json);
            assert_eq!(
                serde_json::from_str::<SuggestionSource>(json).unwrap(),
                source
            );
        }
    }

    #[test]
    fn suggestion_with_diff_ops() {
        let suggestion = Suggestion {
            text: "checkout".into(),
            replace_start: 4,
            replace_end: 12,
            confidence: 0.7,
            source: SuggestionSource::Spec,
            description: None,
            diff_ops: Some(vec![DiffOp::Keep('c'), DiffOp::Insert('h')]),
        };
        let json = serde_json::to_string(&suggestion).unwrap();
        assert!(json.contains("diff_ops"));
        let parsed: Suggestion = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.diff_ops.unwrap().len(), 2);
    }

    #[test]
    fn suggestion_without_diff_ops_backward_compat() {
        let json = r#"{"text":"checkout","replace_start":4,"replace_end":6,"confidence":0.9,"source":"spec","description":null}"#;
        let parsed: Suggestion = serde_json::from_str(json).unwrap();
        assert!(parsed.diff_ops.is_none());
    }

    #[test]
    fn powershell_request_with_windows_path() {
        let json = r#"{"input":"cd C:\\Users\\iamsu","cursor":17,"cwd":"D:\\projects\\nighthawk","shell":"powershell"}"#;
        let parsed: CompletionRequest = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.input, r"cd C:\Users\iamsu");
        // Cursor sits at the end of the (ASCII) input.
        assert_eq!(parsed.cursor, 17);
        assert_eq!(parsed.cursor, parsed.input.len());
        assert_eq!(parsed.cwd, PathBuf::from(r"D:\projects\nighthawk"));
        assert_eq!(parsed.shell, Shell::PowerShell);
    }

    #[test]
    fn powershell_request_with_quotes_in_input() {
        let json = r#"{"input":"echo \"hello world\"","cursor":18,"cwd":"C:\\","shell":"pwsh"}"#;
        let parsed: CompletionRequest = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.input, r#"echo "hello world""#);
        // Cursor sits at the end of the (ASCII) input.
        assert_eq!(parsed.cursor, 18);
        assert_eq!(parsed.cursor, parsed.input.len());
        assert_eq!(parsed.shell, Shell::PowerShell);
    }

    #[test]
    fn powershell_request_unc_path() {
        let json = r#"{"input":"dir","cursor":3,"cwd":"\\\\server\\share","shell":"powershell"}"#;
        let parsed: CompletionRequest = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.cwd, PathBuf::from(r"\\server\share"));
    }

    #[test]
    fn suggestion_none_diff_ops_omitted_in_json() {
        let suggestion = Suggestion {
            text: "checkout".into(),
            replace_start: 4,
            replace_end: 6,
            confidence: 0.9,
            source: SuggestionSource::Spec,
            description: None,
            diff_ops: None,
        };
        let json = serde_json::to_string(&suggestion).unwrap();
        assert!(
            !json.contains("diff_ops"),
            "None diff_ops should be omitted: {}",
            json
        );
    }

    #[test]
    fn shell_from_str_basics() {
        assert_eq!("zsh".parse::<Shell>().unwrap(), Shell::Zsh);
        assert_eq!("bash".parse::<Shell>().unwrap(), Shell::Bash);
        assert_eq!("fish".parse::<Shell>().unwrap(), Shell::Fish);
        assert_eq!("powershell".parse::<Shell>().unwrap(), Shell::PowerShell);
        assert_eq!("pwsh".parse::<Shell>().unwrap(), Shell::PowerShell);
        assert_eq!("nushell".parse::<Shell>().unwrap(), Shell::Nushell);
        assert_eq!("nu".parse::<Shell>().unwrap(), Shell::Nushell);
        assert_eq!("sh".parse::<Shell>().unwrap(), Shell::Bash);
    }

    #[test]
    fn shell_from_str_case_insensitive() {
        assert_eq!("ZSH".parse::<Shell>().unwrap(), Shell::Zsh);
        assert_eq!("PowerShell".parse::<Shell>().unwrap(), Shell::PowerShell);
        assert_eq!("BASH".parse::<Shell>().unwrap(), Shell::Bash);
        assert_eq!("Fish".parse::<Shell>().unwrap(), Shell::Fish);
    }

    #[test]
    fn shell_from_str_versioned() {
        assert_eq!("bash-5.2".parse::<Shell>().unwrap(), Shell::Bash);
        assert_eq!("zsh-5.9".parse::<Shell>().unwrap(), Shell::Zsh);
    }

    #[test]
    fn shell_from_str_unknown() {
        assert!("ksh".parse::<Shell>().is_err());
        assert!("csh".parse::<Shell>().is_err());
        assert!("tcsh".parse::<Shell>().is_err());
    }

    #[test]
    fn shell_from_str_empty() {
        assert!("".parse::<Shell>().is_err());
    }

    #[test]
    fn detect_from_nighthawk_shell_override() {
        let shell = Shell::detect_from(Some("powershell".into()), Some("/bin/zsh".into()));
        assert_eq!(shell, Shell::PowerShell);
    }

    #[test]
    fn detect_from_unknown_override_falls_through() {
        let shell = Shell::detect_from(Some("ksh".into()), Some("/bin/fish".into()));
        #[cfg(not(windows))]
        assert_eq!(shell, Shell::Fish);
        #[cfg(windows)]
        assert_eq!(shell, Shell::PowerShell);
    }

    #[test]
    fn detect_from_shell_path_parsing() {
        let shell = Shell::detect_from(None, Some("/usr/local/bin/fish".into()));
        #[cfg(not(windows))]
        assert_eq!(shell, Shell::Fish);
        #[cfg(windows)]
        assert_eq!(shell, Shell::PowerShell);
    }

    #[test]
    fn detect_from_shell_pwsh_on_unix() {
        // PowerShell on every platform: parsed from $SHELL on Unix, the fixed default on Windows.
        let shell = Shell::detect_from(None, Some("/usr/bin/pwsh".into()));
        assert_eq!(shell, Shell::PowerShell);
    }

    #[test]
    fn detect_from_no_env_vars() {
        let shell = Shell::detect_from(None, None);
        #[cfg(not(windows))]
        assert_eq!(shell, Shell::Zsh);
        #[cfg(windows)]
        assert_eq!(shell, Shell::PowerShell);
    }

    #[test]
    fn shell_from_str_roundtrip() {
        for shell in ALL_SHELLS {
            assert_eq!(
                shell.as_str().parse::<Shell>().unwrap(),
                shell,
                "roundtrip failed for {:?}",
                shell,
            );
        }
    }
}