nabla_cli/routes/binary.rs

1// src/routes/binary.rs
2use anyhow::Result;
3use axum::{
4    extract::{Multipart, Request, State},
5    http::StatusCode,
6    response::Json,
7};
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10use uuid::Uuid;
11
// Inference provider types used by the chat endpoint.
13use crate::enterprise::providers::{
14    GenerationOptions, GenerationResponse, HTTPProvider, InferenceProvider,
15};
16
// Crate-local application state, binary-analysis API, and plan-feature flags.
19use crate::{
20    AppState,
21    binary::{BinaryAnalysis, VulnerabilityMatch, analyze_binary, scan_binary_vulnerabilities},
22    middleware::PlanFeatures,
23};
24
25/// Validates and sanitizes a file path to prevent path traversal attacks
26/// Returns the canonicalized path if valid, or an error if the path is unsafe
27pub fn validate_file_path(
28    file_path: &str,
29) -> Result<std::path::PathBuf, (StatusCode, Json<ErrorResponse>)> {
30    // 1. Check for path traversal attempts (..) using string operations
31    if file_path.contains("..") {
32        return Err((
33            StatusCode::BAD_REQUEST,
34            Json(ErrorResponse {
35                error: "invalid_input".to_string(),
36                message: "Path traversal not allowed".to_string(),
37            }),
38        ));
39    }
40
41    // 2. Check for absolute paths using string operations
42    if file_path.starts_with('/') || (cfg!(windows) && file_path.contains(':')) {
43        return Err((
44            StatusCode::BAD_REQUEST,
45            Json(ErrorResponse {
46                error: "invalid_input".to_string(),
47                message: "Absolute paths not allowed".to_string(),
48            }),
49        ));
50    }
51
52    // 3. Create path only after validation
53    let path = std::path::Path::new(file_path);
54
55    // 4. Define allowed directory (restrict to current working directory)
56    let base_dir = std::env::current_dir().map_err(|_e| {
57        (
58            StatusCode::INTERNAL_SERVER_ERROR,
59            Json(ErrorResponse {
60                error: "server_error".to_string(),
61                message: "Failed to get current directory".to_string(),
62            }),
63        )
64    })?;
65
66    // 5. Build the full path and canonicalize it
67    let full_path = base_dir.join(path);
68    let canonical_path = full_path.canonicalize().map_err(|_e| {
69        (
70            StatusCode::BAD_REQUEST,
71            Json(ErrorResponse {
72                error: "invalid_input".to_string(),
73                message: "Invalid file path".to_string(),
74            }),
75        )
76    })?;
77
78    // 6. Security check: Ensure the canonicalized path is within the allowed directory
79    if !canonical_path.starts_with(&base_dir) {
80        return Err((
81            StatusCode::BAD_REQUEST,
82            Json(ErrorResponse {
83                error: "invalid_input".to_string(),
84                message: "Access denied: Path outside allowed directory".to_string(),
85            }),
86        ));
87    }
88
89    // 7. Check if file exists and is a regular file (not a symlink or directory)
90    if !canonical_path.exists() {
91        return Err((
92            StatusCode::BAD_REQUEST,
93            Json(ErrorResponse {
94                error: "file_not_found".to_string(),
95                message: "File not found".to_string(),
96            }),
97        ));
98    }
99
100    // 8. Check if it's a regular file (not a symlink, directory, etc.)
101    let metadata = std::fs::metadata(&canonical_path).map_err(|_e| {
102        (
103            StatusCode::BAD_REQUEST,
104            Json(ErrorResponse {
105                error: "file_error".to_string(),
106                message: "Cannot access file".to_string(),
107            }),
108        )
109    })?;
110
111    if !metadata.is_file() {
112        return Err((
113            StatusCode::BAD_REQUEST,
114            Json(ErrorResponse {
115                error: "invalid_input".to_string(),
116                message: "Path is not a regular file".to_string(),
117            }),
118        ));
119    }
120
121    Ok(canonical_path)
122}
123
/// Response body returned by `upload_and_analyze_binary` (`POST /binary`).
#[derive(Debug, Serialize)]
pub struct BinaryUploadResponse {
    /// Identifier copied from the analysis result (`analysis.id`).
    pub id: Uuid,
    /// SHA-256 hash copied from the analysis result (`analysis.hash_sha256`).
    pub hash: String,
    /// Full analysis of the uploaded binary.
    pub analysis: BinaryAnalysis,
}
130
/// JSON error payload that the handlers in this module return together with an HTTP status.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
    /// Short machine-readable error code, e.g. `"invalid_input"` or `"analysis_error"`.
    pub error: String,
    /// Human-readable description of the failure.
    pub message: String,
}
136
/// Response body returned by `check_cve`.
#[derive(Debug, Serialize)]
pub struct CveScanResponse {
    /// Matches produced by `scan_binary_vulnerabilities` for the uploaded binary.
    pub matches: Vec<VulnerabilityMatch>,
}
141
/// JSON request body accepted by `chat_with_binary`.
#[derive(Debug, Deserialize)]
pub struct ChatRequest {
    /// Path to the file to analyze instead of raw content; checked by `validate_file_path`
    /// (must be relative to, and stay inside, the server's working directory).
    pub file_path: String,
    /// The user's question about the binary.
    pub question: String,
    /// For local GGUF files; forwarded to the provider as `options.model_path`.
    pub model_path: Option<String>,
    /// For remote HF repos; forwarded to the provider as `options.hf_repo`.
    pub hf_repo: Option<String>,
    /// Provider selector; only `"http"` is currently accepted.
    pub provider: String,
    /// URL of the inference server; defaults to `http://localhost:11434` and is
    /// SSRF-validated before use.
    pub inference_url: Option<String>,
    /// Token for third-party authentication, passed to `HTTPProvider::new`.
    pub provider_token: Option<String>,
    /// Generation options; `model_path` and `hf_repo` from this request overwrite theirs.
    pub options: Option<GenerationOptions>,
}
153
/// Response body returned by `chat_with_binary`.
#[derive(Debug, Serialize)]
pub struct ChatResponse {
    /// Text generated by the inference provider.
    pub answer: String,
    /// The request's `hf_repo` if given, else its `model_path`, else `"unknown"`.
    pub model_used: String,
    /// Token count reported in the provider's response.
    pub tokens_used: usize,
}
160
161pub async fn health_check(State(mut state): State<AppState>) -> Json<serde_json::Value> {
162    let fips_status = if state.config.fips_mode {
163        state.crypto_provider.validate_fips_compliance().is_ok()
164    } else {
165        false
166    };
167
168    let fips_details = if state.config.fips_mode {
169        json!({
170            "fips_mode": true,
171            "fips_compliant": fips_status,
172            "fips_validation": state.config.fips_validation,
173            "approved_algorithms": [
174                "SHA-256",
175                "SHA-512",
176                "HMAC-SHA256",
177                "AES-256-GCM",
178                "TLS13_AES_256_GCM_SHA384"
179            ],
180            "hash_algorithm": "SHA-512",
181            "random_generator": "FIPS-compliant OS RNG"
182        })
183    } else {
184        json!({
185            "fips_mode": false,
186            "fips_compliant": false,
187            "fips_validation": false,
188            "hash_algorithm": "Blake3",
189            "random_generator": "Standard RNG"
190        })
191    };
192
193    Json(json!({
194        "status": "healthy",
195        "service": "Nabla",
196        "version": env!("CARGO_PKG_VERSION"),
197        "fips": fips_details
198    }))
199}
200
201// POST /binary - Upload and analyze binary
202pub async fn upload_and_analyze_binary(
203    State(state): State<AppState>,
204    mut multipart: Multipart,
205) -> Result<Json<BinaryUploadResponse>, (StatusCode, Json<ErrorResponse>)> {
206    let mut file_name = "unknown".to_string();
207    let mut contents = vec![];
208    let mut found_file = false;
209
210    // Extract file from multipart form
211    while let Some(field) = multipart.next_field().await.map_err(|e| {
212        (
213            StatusCode::BAD_REQUEST,
214            Json(ErrorResponse {
215                error: "multipart_error".to_string(),
216                message: format!("Failed to parse multipart form: {}", e),
217            }),
218        )
219    })? {
220        let field_name = field.name().unwrap_or("unknown_field").to_string();
221        tracing::debug!("Processing multipart field: '{}'", field_name);
222
223        // Get filename if present
224        let field_filename = field.file_name().map(|s| s.to_string());
225        if let Some(name) = &field_filename {
226            file_name = name.clone();
227            tracing::info!("Found filename in multipart: '{}'", file_name);
228        }
229
230        // Read field contents
231        let field_contents = field
232            .bytes()
233            .await
234            .map_err(|e| {
235                (
236                    StatusCode::BAD_REQUEST,
237                    Json(ErrorResponse {
238                        error: "read_error".to_string(),
239                        message: format!("Failed to read field '{}' contents: {}", field_name, e),
240                    }),
241                )
242            })?
243            .to_vec();
244
245        tracing::debug!(
246            "Field '{}': {} bytes, filename: {:?}",
247            field_name,
248            field_contents.len(),
249            field_filename
250        );
251
252        // Only use content from file fields, not text fields
253        if !field_contents.is_empty()
254            && (
255                field_name == "file"
256                    || field_name == "binary"
257                    || field_filename.is_some()
258                    || field_contents.len() > 10
259                // Assume larger content is the file
260            )
261        {
262            contents = field_contents;
263            found_file = true;
264            tracing::info!(
265                "Using {} bytes from field '{}' as file content",
266                contents.len(),
267                field_name
268            );
269        }
270    }
271
272    if !found_file {
273        tracing::warn!("No file field found in multipart form");
274    }
275
276    if contents.is_empty() {
277        return Err((
278            StatusCode::BAD_REQUEST,
279            Json(ErrorResponse {
280                error: "empty_file".to_string(),
281                message: "No file content provided".to_string(),
282            }),
283        ));
284    }
285
286    // Log the received file info
287    tracing::info!("Analyzing file: '{}' ({} bytes)", file_name, contents.len());
288
289    // Analyze the binary
290    let analysis = analyze_binary(&file_name, &contents, &state.crypto_provider)
291        .await
292        .map_err(|e| {
293            tracing::error!("Binary analysis failed: {}", e);
294            (
295                StatusCode::INTERNAL_SERVER_ERROR,
296                Json(ErrorResponse {
297                    error: "analysis_error".to_string(),
298                    message: format!("Failed to analyze binary: {}", e),
299                }),
300            )
301        })?;
302
303    tracing::info!(
304        "Analysis completed for {}: format={}, arch={}, {} strings",
305        file_name,
306        analysis.format,
307        analysis.architecture,
308        analysis.embedded_strings.len()
309    );
310
311    Ok(Json(BinaryUploadResponse {
312        id: analysis.id,
313        hash: analysis.hash_sha256.clone(),
314        analysis,
315    }))
316}
317
318pub async fn check_cve(
319    State(state): State<AppState>,
320    mut multipart: Multipart,
321) -> Result<Json<CveScanResponse>, (StatusCode, Json<ErrorResponse>)> {
322    tracing::info!("check_cve handler called");
323
324    let mut contents = vec![];
325    let mut file_name = "uploaded.bin".to_string();
326
327    while let Some(field) = multipart.next_field().await.map_err(|e| {
328        tracing::error!("Error parsing multipart: {}", e);
329        (
330            StatusCode::BAD_REQUEST,
331            Json(ErrorResponse {
332                error: "multipart_error".to_string(),
333                message: format!("Failed to parse multipart form: {}", e),
334            }),
335        )
336    })? {
337        tracing::info!("Found field in multipart: {:?}", field.name());
338
339        if let Some(name) = field.file_name() {
340            file_name = name.to_string();
341            tracing::info!("Uploaded file: {}", file_name);
342        }
343
344        contents = field
345            .bytes()
346            .await
347            .map_err(|e| {
348                tracing::error!("Error reading file: {}", e);
349                (
350                    StatusCode::BAD_REQUEST,
351                    Json(ErrorResponse {
352                        error: "read_error".to_string(),
353                        message: format!("Failed to read file contents: {}", e),
354                    }),
355                )
356            })?
357            .to_vec();
358    }
359
360    if contents.is_empty() {
361        tracing::warn!("No file content provided");
362        return Err((
363            StatusCode::BAD_REQUEST,
364            Json(ErrorResponse {
365                error: "empty_file".to_string(),
366                message: "No file content provided".to_string(),
367            }),
368        ));
369    }
370
371    let analysis = analyze_binary(&file_name, &contents, &state.crypto_provider)
372        .await
373        .map_err(|e| {
374            tracing::error!("Binary analysis failed: {}", e);
375            (
376                StatusCode::INTERNAL_SERVER_ERROR,
377                Json(ErrorResponse {
378                    error: "analysis_error".to_string(),
379                    message: format!("Failed to analyze binary: {}", e),
380                }),
381            )
382        })?;
383
384    tracing::info!("Binary analysis complete: {:?}", analysis);
385
386    let matches = scan_binary_vulnerabilities(&analysis);
387    tracing::info!("Vuln scan complete. {} match(es)", matches.len());
388
389    Ok(Json(CveScanResponse { matches }))
390}
391
392// POST /binary/diff - compare two binaries
393pub async fn diff_binaries(
394    State(state): State<AppState>,
395    mut multipart: Multipart,
396) -> Result<Json<Value>, (StatusCode, Json<ErrorResponse>)> {
397    // Extract two files
398    let mut files: Vec<(String, Vec<u8>)> = Vec::new();
399    while let Some(field) = multipart.next_field().await.map_err(|e| {
400        (
401            StatusCode::BAD_REQUEST,
402            Json(ErrorResponse {
403                error: "multipart_error".to_string(),
404                message: format!("Failed parsing multipart: {}", e),
405            }),
406        )
407    })? {
408        let name = field
409            .file_name()
410            .map(|s| s.to_string())
411            .unwrap_or_else(|| "file".to_string());
412        let bytes = field
413            .bytes()
414            .await
415            .map_err(|e| {
416                (
417                    StatusCode::BAD_REQUEST,
418                    Json(ErrorResponse {
419                        error: "read_error".to_string(),
420                        message: format!("Failed to read file: {}", e),
421                    }),
422                )
423            })?
424            .to_vec();
425        files.push((name, bytes));
426    }
427
428    if files.len() != 2 {
429        return Err((
430            StatusCode::BAD_REQUEST,
431            Json(ErrorResponse {
432                error: "invalid_input".to_string(),
433                message: "Exactly two files must be provided".to_string(),
434            }),
435        ));
436    }
437
438    // Analyze each binary to get symbol information
439    let analysis1 = analyze_binary(&files[0].0, &files[0].1, &state.crypto_provider)
440        .await
441        .map_err(|e| {
442            (
443                StatusCode::INTERNAL_SERVER_ERROR,
444                Json(ErrorResponse {
445                    error: "analysis_error".to_string(),
446                    message: format!("Failed to analyze first binary: {}", e),
447                }),
448            )
449        })?;
450
451    let analysis2 = analyze_binary(&files[1].0, &files[1].1, &state.crypto_provider)
452        .await
453        .map_err(|e| {
454            (
455                StatusCode::INTERNAL_SERVER_ERROR,
456                Json(ErrorResponse {
457                    error: "analysis_error".to_string(),
458                    message: format!("Failed to analyze second binary: {}", e),
459                }),
460            )
461        })?;
462
463    use sha2::Digest;
464    use std::collections::HashSet;
465
466    let mut meta = serde_json::Map::new();
467    for (idx, (name, data)) in files.iter().enumerate() {
468        meta.insert(format!("file{}_name", idx + 1), serde_json::json!(name));
469        meta.insert(
470            format!("file{}_size", idx + 1),
471            serde_json::json!(data.len()),
472        );
473        meta.insert(
474            format!("file{}_sha256", idx + 1),
475            serde_json::json!(format!("{:x}", sha2::Sha256::digest(data))),
476        );
477    }
478    meta.insert(
479        "size_diff_bytes".to_string(),
480        serde_json::json!((files[0].1.len() as i64) - (files[1].1.len() as i64)),
481    );
482
483    // Symbol-level diffs
484    let imports1: HashSet<String> = analysis1.imports.iter().cloned().collect();
485    let imports2: HashSet<String> = analysis2.imports.iter().cloned().collect();
486    let exports1: HashSet<String> = analysis1.exports.iter().cloned().collect();
487    let exports2: HashSet<String> = analysis2.exports.iter().cloned().collect();
488    let symbols1: HashSet<String> = analysis1.detected_symbols.iter().cloned().collect();
489    let symbols2: HashSet<String> = analysis2.detected_symbols.iter().cloned().collect();
490
491    let imports_added: Vec<String> = imports2.difference(&imports1).cloned().collect();
492    let imports_removed: Vec<String> = imports1.difference(&imports2).cloned().collect();
493    let exports_added: Vec<String> = exports2.difference(&exports1).cloned().collect();
494    let exports_removed: Vec<String> = exports1.difference(&exports2).cloned().collect();
495    let symbols_added: Vec<String> = symbols2.difference(&symbols1).cloned().collect();
496    let symbols_removed: Vec<String> = symbols1.difference(&symbols2).cloned().collect();
497
498    meta.insert(
499        "imports_added".to_string(),
500        serde_json::json!(imports_added),
501    );
502    meta.insert(
503        "imports_removed".to_string(),
504        serde_json::json!(imports_removed),
505    );
506    meta.insert(
507        "exports_added".to_string(),
508        serde_json::json!(exports_added),
509    );
510    meta.insert(
511        "exports_removed".to_string(),
512        serde_json::json!(exports_removed),
513    );
514    meta.insert(
515        "symbols_added".to_string(),
516        serde_json::json!(symbols_added),
517    );
518    meta.insert(
519        "symbols_removed".to_string(),
520        serde_json::json!(symbols_removed),
521    );
522
523    Ok(Json(meta.into()))
524}
525
526#[axum::debug_handler]
527pub async fn chat_with_binary(
528    State(state): State<AppState>,
529    req: Request,
530) -> Result<Json<ChatResponse>, (StatusCode, Json<ErrorResponse>)> {
531    // Get features from request extensions first (set by middleware)
532    let features = req
533        .extensions()
534        .get::<PlanFeatures>()
535        .cloned()
536        .unwrap_or_else(|| PlanFeatures::default_oss());
537
538    // Extract JSON body manually
539    let (_parts, body) = req.into_parts();
540    let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
541        Ok(bytes) => bytes,
542        Err(_) => {
543            return Err((
544                StatusCode::BAD_REQUEST,
545                Json(ErrorResponse {
546                    error: "invalid_body".to_string(),
547                    message: "Failed to read request body".to_string(),
548                }),
549            ));
550        }
551    };
552
553    let request: ChatRequest = match serde_json::from_slice(&body_bytes) {
554        Ok(req) => req,
555        Err(_) => {
556            return Err((
557                StatusCode::BAD_REQUEST,
558                Json(ErrorResponse {
559                    error: "invalid_json".to_string(),
560                    message: "Invalid JSON in request body".to_string(),
561                }),
562            ));
563        }
564    };
565
566    // Chat is always a paid feature - check the feature flag
567    if !features.chat_enabled {
568        return Err((
569            StatusCode::FORBIDDEN,
570            Json(ErrorResponse {
571                error: "chat_not_available".to_string(),
572                message: "Chat feature is not available with your current license. Please schedule a demo to upgrade your plan: https://cal.com/team/atelier-logos/platform-intro".to_string(),
573            }),
574        ));
575    }
576    // Validate and sanitize the file path using the helper function
577    let canonical_path = validate_file_path(&request.file_path)?;
578
579    // Read the file using the validated canonicalized path
580    let file_content = tokio::fs::read(&canonical_path).await.map_err(|_e| {
581        (
582            StatusCode::BAD_REQUEST,
583            Json(ErrorResponse {
584                error: "file_read_error".to_string(),
585                message: "Failed to read file".to_string(),
586            }),
587        )
588    })?;
589
590    // Extract filename safely from the validated canonical path
591    let file_name = canonical_path
592        .file_name()
593        .and_then(|n| n.to_str())
594        .unwrap_or("unknown")
595        .to_string();
596
597    let analysis = analyze_binary(&file_name, &file_content, &state.crypto_provider)
598        .await
599        .map_err(|e| {
600            (
601                StatusCode::INTERNAL_SERVER_ERROR,
602                Json(ErrorResponse {
603                    error: "analysis_error".to_string(),
604                    message: format!("Failed to analyze binary: {}", e),
605                }),
606            )
607        })?;
608
609    // Store model info before moving values
610    let model_used = request
611        .hf_repo
612        .as_ref()
613        .or(request.model_path.as_ref())
614        .map(|s| s.clone())
615        .unwrap_or_else(|| "unknown".to_string());
616
617    // Handle inference with HTTP provider
618    let response = match request.provider.as_str() {
619        "http" => {
620            let inference_url = request
621                .inference_url
622                .unwrap_or_else(|| "http://localhost:11434".to_string());
623
624            // Validate the inference URL for SSRF protection
625            let ssrf_validator = crate::ssrf_protection::SSRFValidator::new();
626            let validated_url = ssrf_validator.validate_url(&inference_url).map_err(|e| {
627                (
628                    StatusCode::BAD_REQUEST,
629                    Json(ErrorResponse {
630                        error: "ssrf_protection_violation".to_string(),
631                        message: format!("SSRF protection violation: {}", e),
632                    }),
633                )
634            })?;
635
636            let provider =
637                HTTPProvider::new(validated_url.to_string(), None, request.provider_token);
638
639            let mut options = request.options.unwrap_or_default();
640            options.model_path = request.model_path;
641            options.hf_repo = request.hf_repo;
642            // Note: options.model is already set from the request.options if provided
643
644            chat_with_provider(&analysis, &request.question, &provider, &options)
645                .await
646                .map_err(|e| {
647                    (
648                        StatusCode::INTERNAL_SERVER_ERROR,
649                        Json(ErrorResponse {
650                            error: "inference_error".to_string(),
651                            message: format!("Failed to chat with binary: {}", e),
652                        }),
653                    )
654                })?
655        }
656        _ => {
657            return Err((
658                StatusCode::BAD_REQUEST,
659                Json(ErrorResponse {
660                    error: "invalid_provider".to_string(),
661                    message: "Provider must be 'http'".to_string(),
662                }),
663            ));
664        }
665    };
666
667    Ok(Json(ChatResponse {
668        answer: response.text,
669        model_used,
670        tokens_used: response.tokens_used,
671    }))
672}
673
674async fn chat_with_provider(
675    analysis: &BinaryAnalysis,
676    user_question: &str,
677    provider: &dyn InferenceProvider,
678    options: &GenerationOptions,
679) -> Result<GenerationResponse, anyhow::Error> {
680    // Check if the question asks for JSON output
681    let is_json_request = user_question.to_lowercase().contains("json")
682        || user_question.to_lowercase().contains("sbom")
683        || user_question.to_lowercase().contains("cyclonedx");
684
685    let context = if is_json_request {
686        format!(
687            "Binary Analysis Context:\n\
688             - File: {}\n\
689             - Format: {}\n\
690             - Architecture: {}\n\
691             - Size: {} bytes\n\
692             - Linked Libraries: {}\n\
693             - Imports: {}\n\
694             - Exports: {}\n\
695             - Embedded Strings: {}\n\n\
696             User Question: {}\n\n\
697             CRITICAL: You must return ONLY raw JSON. Do NOT wrap it in quotes or escape it as a string. Return the actual JSON object directly. Do not include any explanations, markdown, or code blocks. The response should start with {{ and end with }}.",
698            analysis.file_name,
699            analysis.format,
700            analysis.architecture,
701            analysis.size_bytes,
702            analysis.linked_libraries.join(", "),
703            analysis.imports.join(", "),
704            analysis.exports.join(", "),
705            analysis.embedded_strings.join(", "),
706            user_question
707        )
708    } else {
709        format!(
710            "Binary Analysis Context:\n\
711         - File: {}\n\
712         - Format: {}\n\
713         - Architecture: {}\n\
714         - Size: {} bytes\n\
715         - Linked Libraries: {}\n\
716         - Imports: {}\n\
717         - Exports: {}\n\
718         - Embedded Strings: {}\n\n\
719         User Question: {}\n\n\
720         Please provide a helpful answer about this binary based on the analysis data.",
721            analysis.file_name,
722            analysis.format,
723            analysis.architecture,
724            analysis.size_bytes,
725            analysis.linked_libraries.join(", "),
726            analysis.imports.join(", "),
727            analysis.exports.join(", "),
728            analysis.embedded_strings.join(", "),
729            user_question
730        )
731    };
732
733    let mut response = provider
734        .generate(&context, options)
735        .await
736        .map_err(|e| anyhow::anyhow!("Inference failed: {}", e))?;
737
738    // Post-process JSON responses to handle cases where the model returns JSON as a string
739    if is_json_request {
740        let text = response.text.trim();
741        // If the response looks like a JSON string (starts and ends with quotes), try to parse it
742        if text.starts_with('"') && text.ends_with('"') {
743            if let Ok(parsed_json) = serde_json::from_str::<serde_json::Value>(text) {
744                response.text = serde_json::to_string(&parsed_json)
745                    .map_err(|e| anyhow::anyhow!("Failed to serialize JSON: {}", e))?;
746            }
747        }
748    }
749
750    Ok(response)
751}