use serde_json::Value;
use std::fs;
use std::path::PathBuf;
use crate::error::RjdError;
/// How the file loaders treat a path that is itself a symbolic link.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SymlinkPolicy {
    /// Refuse the path with `RjdError::SymlinkRejected`.
    Reject,
    /// Resolve the link with `canonicalize` and load the target instead.
    Follow,
}
/// Verifies that `value` is nested no deeper than `max_depth` levels.
///
/// The root counts as depth 1. Every element of an array and every value of an
/// object sits one level below its container, including scalar leaves.
///
/// Returns `Err(depth)` with the first depth found beyond the limit (always
/// `max_depth + 1`, since nodes are only expanded while within the limit).
fn check_json_depth(value: &Value, max_depth: usize) -> Result<(), usize> {
    // Explicit worklist of (node, depth) pairs instead of recursion.
    let mut pending: Vec<(&Value, usize)> = vec![(value, 1)];
    while let Some((node, depth)) = pending.pop() {
        if depth > max_depth {
            return Err(depth);
        }
        let child_depth = depth + 1;
        match node {
            Value::Object(fields) => {
                pending.extend(fields.values().map(|child| (child, child_depth)));
            }
            Value::Array(items) => {
                pending.extend(items.iter().map(|child| (child, child_depth)));
            }
            _ => {}
        }
    }
    Ok(())
}
/// Parses `content` as JSON and then enforces `max_depth` via `check_json_depth`.
///
/// Errors are returned as human-readable strings, either
/// `"Failed to parse JSON: …"` or `"JSON depth N exceeds limit M"`.
///
/// NOTE(review): `serde_json::from_str` applies its own recursion limit
/// (128 by default, unless the crate is built with `unbounded_depth` and the
/// limit is disabled). Deeper input therefore fails in the first branch before
/// the depth check runs. Confirm whether limits above ~128 are reachable.
fn parse_json_with_depth_limit(content: &str, max_depth: usize) -> Result<Value, String> {
    let parsed = match serde_json::from_str::<Value>(content) {
        Ok(parsed) => parsed,
        Err(err) => return Err(format!("Failed to parse JSON: {}", err)),
    };
    match check_json_depth(&parsed, max_depth) {
        Ok(()) => Ok(parsed),
        Err(found) => Err(format!("JSON depth {} exceeds limit {}", found, max_depth)),
    }
}
/// Default upper bound on input file size in bytes (100 MiB).
const DEFAULT_MAX_FILE_SIZE: u64 = 100 << 20;
/// Default upper bound on JSON nesting depth, counted with the root as level 1.
const DEFAULT_MAX_JSON_DEPTH: usize = 1_000;
/// Resource limits applied when loading JSON input.
#[derive(Clone, Copy, Debug)]
pub struct LoadConfig {
    /// Largest accepted file, in bytes.
    pub max_file_size: u64,
    /// Deepest accepted nesting, with the root value counted as depth 1.
    pub max_json_depth: usize,
}
impl LoadConfig {
    /// Builds a config from the `RJD_MAX_FILE_SIZE` and `RJD_MAX_JSON_DEPTH`
    /// environment variables.
    ///
    /// Each variable is handled independently. If it is unset, is not valid
    /// Unicode, or does not parse as the field's integer type, the built-in
    /// default is used and the bad value is silently ignored.
    pub fn from_env() -> Self {
        fn read_or<T: std::str::FromStr>(name: &str, fallback: T) -> T {
            match std::env::var(name) {
                Ok(raw) => raw.parse::<T>().unwrap_or(fallback),
                Err(_) => fallback,
            }
        }

        let max_file_size = read_or("RJD_MAX_FILE_SIZE", DEFAULT_MAX_FILE_SIZE);
        let max_json_depth = read_or("RJD_MAX_JSON_DEPTH", DEFAULT_MAX_JSON_DEPTH);
        Self {
            max_file_size,
            max_json_depth,
        }
    }
}
impl Default for LoadConfig {
    /// Returns the built-in limits: `DEFAULT_MAX_FILE_SIZE` bytes and
    /// `DEFAULT_MAX_JSON_DEPTH` nesting levels.
    fn default() -> Self {
        let max_file_size = DEFAULT_MAX_FILE_SIZE;
        let max_json_depth = DEFAULT_MAX_JSON_DEPTH;
        Self {
            max_file_size,
            max_json_depth,
        }
    }
}
impl LoadConfig {
pub fn with_limits(max_file_size: u64, max_json_depth: usize) -> Self {
Self {
max_file_size,
max_json_depth,
}
}
pub fn merge_with_cli(&self, max_file_size: Option<u64>, max_depth: Option<usize>) -> Self {
Self {
max_file_size: max_file_size.unwrap_or(self.max_file_size),
max_json_depth: max_depth.unwrap_or(self.max_json_depth),
}
}
}
/// Loads a JSON file with `LoadConfig::default()` limits, rejecting symlinks.
pub fn load_json_file(path: &PathBuf) -> Result<Value, RjdError> {
    let defaults = LoadConfig::default();
    load_json_file_with_config(path, &defaults)
}
/// Loads a JSON file with the given limits, using `SymlinkPolicy::Reject`.
pub fn load_json_file_with_config(path: &PathBuf, config: &LoadConfig) -> Result<Value, RjdError> {
    let policy = SymlinkPolicy::Reject;
    load_json_file_with_config_and_policy(path, config, policy)
}
pub fn load_json_file_with_config_and_policy(
path: &PathBuf,
config: &LoadConfig,
policy: SymlinkPolicy,
) -> Result<Value, RjdError> {
if !path.exists() {
return Err(RjdError::FileRead {
path: path.clone(),
source: std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("File not found: {}", path.display()),
),
});
}
let metadata = path
.symlink_metadata()
.map_err(|source| RjdError::FileRead {
path: path.clone(),
source,
})?;
if metadata.is_symlink() {
match policy {
SymlinkPolicy::Reject => {
return Err(RjdError::SymlinkRejected { path: path.clone() });
}
SymlinkPolicy::Follow => {
let canonical = path.canonicalize().map_err(|source| RjdError::FileRead {
path: path.clone(),
source,
})?;
return load_json_file_with_config_and_policy(&canonical, config, policy);
}
}
}
if !path.is_file() {
return Err(RjdError::FileRead {
path: path.clone(),
source: std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("Not a file: {}", path.display()),
),
});
}
let metadata = fs::metadata(path).map_err(|source| RjdError::FileRead {
path: path.clone(),
source,
})?;
let file_size = metadata.len();
if file_size > config.max_file_size {
return Err(RjdError::FileTooLarge {
path: path.clone(),
size: file_size,
limit: config.max_file_size,
});
}
let content = fs::read_to_string(path).map_err(|source| RjdError::FileRead {
path: path.clone(),
source,
})?;
let value = parse_json_with_depth_limit(&content, config.max_json_depth).map_err(|msg| {
RjdError::JsonParse {
path: path.clone(),
source: serde_json::Error::io(std::io::Error::other(msg)),
}
})?;
Ok(value)
}
/// Loads JSON from `input`, which may be inline JSON text or a path to a JSON
/// file. Uses `LoadConfig::default()` limits.
pub fn load_json_input(input: &str) -> Result<Value, RjdError> {
    let defaults = LoadConfig::default();
    load_json_input_with_config(input, &defaults)
}
/// Loads JSON from inline text or a file path with the given limits.
/// File paths that are symlinks are rejected (`SymlinkPolicy::Reject`).
pub fn load_json_input_with_config(input: &str, config: &LoadConfig) -> Result<Value, RjdError> {
    let policy = SymlinkPolicy::Reject;
    load_json_input_with_config_and_policy(input, config, policy)
}
/// Loads JSON from inline text or a file path with the given limits and
/// symlink policy. The input is not forced to be treated as inline JSON.
pub fn load_json_input_with_config_and_policy(
    input: &str,
    config: &LoadConfig,
    policy: SymlinkPolicy,
) -> Result<Value, RjdError> {
    let force_inline = false;
    load_json_input_with_config_policy_and_inline(input, config, policy, force_inline)
}
/// Interprets `input` as inline JSON or as a path to a JSON file.
///
/// Decision order:
/// 1. `force_inline`: parse `input` as JSON, skipping any path lookup.
/// 2. Input that, after trimming, starts with `{` or `[`: parse as inline JSON.
/// 3. Input naming an existing path: load that file with `config` and `policy`.
/// 4. Anything else: try to parse it as inline JSON (e.g. `42`, `"text"`, `null`).
///
/// Every inline parse enforces `config.max_json_depth`.
///
/// # Errors
/// Inline parse failures (syntax or depth) become `RjdError::InvalidInput`.
/// File loads propagate the errors of `load_json_file_with_config_and_policy`.
pub fn load_json_input_with_config_policy_and_inline(
    input: &str,
    config: &LoadConfig,
    policy: SymlinkPolicy,
    force_inline: bool,
) -> Result<Value, RjdError> {
    let invalid_input = || RjdError::InvalidInput {
        input: input.to_string(),
    };

    // Forced inline input previously bypassed the depth limit by calling
    // `serde_json::from_str` directly. It now goes through the same
    // depth-checked parser as the other inline paths.
    if force_inline {
        return parse_json_with_depth_limit(input, config.max_json_depth)
            .map_err(|_| invalid_input());
    }

    let trimmed = input.trim();
    if trimmed.starts_with('{') || trimmed.starts_with('[') {
        return parse_json_with_depth_limit(input, config.max_json_depth)
            .map_err(|_| invalid_input());
    }

    let path = PathBuf::from(input);
    if path.exists() {
        return load_json_file_with_config_and_policy(&path, config, policy);
    }

    parse_json_with_depth_limit(input, config.max_json_depth).map_err(|_| invalid_input())
}
/// Reads JSON from standard input using `LoadConfig::default()` limits.
pub fn load_json_stdin() -> Result<Value, RjdError> {
    let defaults = LoadConfig::default();
    load_json_stdin_with_config(&defaults)
}
/// Reads all of standard input and parses it as JSON, enforcing
/// `config.max_json_depth`.
///
/// Only the depth limit is applied here. `config.max_file_size` is not
/// consulted, so the whole stream is buffered in memory.
///
/// # Errors
/// Both read failures and parse/depth failures are reported as
/// `RjdError::Internal` with a descriptive message.
pub fn load_json_stdin_with_config(config: &LoadConfig) -> Result<Value, RjdError> {
    let stdin = std::io::stdin();
    let content = match std::io::read_to_string(stdin) {
        Ok(text) => text,
        Err(source) => {
            return Err(RjdError::Internal {
                message: format!("Failed to read from stdin: {}", source),
            });
        }
    };
    parse_json_with_depth_limit(&content, config.max_json_depth).map_err(|msg| {
        RjdError::Internal {
            message: format!("Failed to parse JSON from stdin: {}", msg),
        }
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::path::PathBuf;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::NamedTempFile;

    /// Serializes tests that touch the process-wide `RJD_*` environment
    /// variables. The test harness runs tests on parallel threads by default, so
    /// without this one test's `remove_var` can race another test's `set_var`.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Removes both variables read by `LoadConfig::from_env`.
    fn clear_limit_env_vars() {
        std::env::remove_var("RJD_MAX_FILE_SIZE");
        std::env::remove_var("RJD_MAX_JSON_DEPTH");
    }

    /// Holds `ENV_LOCK` for the lifetime of a test. Clears the limit variables on
    /// entry and again on drop, so a failed assertion cannot leak values into
    /// other tests.
    struct EnvVarsGuard {
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvVarsGuard {
        fn acquire() -> Self {
            // A panicking env test poisons the lock. The guarded data is `()`,
            // so recovering the guard is safe.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            clear_limit_env_vars();
            Self { _lock: lock }
        }
    }

    impl Drop for EnvVarsGuard {
        fn drop(&mut self) {
            // Runs before `_lock` is released, so cleanup is still serialized.
            clear_limit_env_vars();
        }
    }

    /// Creates a temporary file holding `contents` and returns its handle and path.
    /// Keep the handle alive while the file is needed; dropping it deletes the
    /// file, so nothing is left behind in the temp directory.
    fn temp_json_file(contents: &str) -> (NamedTempFile, PathBuf) {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(contents.as_bytes()).unwrap();
        temp_file.flush().unwrap();
        let file_path = temp_file.path().to_path_buf();
        (temp_file, file_path)
    }

    #[test]
    fn test_load_valid_json() {
        let (_temp, file_path) = temp_json_file(r#"{"name": "test", "value": 42}"#);
        let result = load_json_file(&file_path);
        assert!(result.is_ok());
        let value = result.unwrap();
        assert_eq!(value["name"], "test");
        assert_eq!(value["value"], 42);
    }

    #[test]
    fn test_load_nonexistent_file() {
        let result = load_json_file(&PathBuf::from("/nonexistent/file.json"));
        assert!(result.is_err());
    }

    #[test]
    fn test_load_invalid_json() {
        let (_temp, file_path) = temp_json_file(r#"{"invalid": json}"#);
        let result = load_json_file(&file_path);
        assert!(result.is_err());
    }

    #[test]
    fn test_load_json_input_inline_json() {
        let result = load_json_input(r#"{"name": "test", "value": 42}"#);
        assert!(result.is_ok());
        let value = result.unwrap();
        assert_eq!(value["name"], "test");
        assert_eq!(value["value"], 42);
    }

    #[test]
    fn test_load_json_input_file() {
        let (_temp, file_path) = temp_json_file(r#"{"name": "test", "value": 42}"#);
        let result = load_json_input(&file_path.to_string_lossy());
        assert!(result.is_ok());
        let value = result.unwrap();
        assert_eq!(value["name"], "test");
        assert_eq!(value["value"], 42);
    }

    #[test]
    fn test_load_json_input_simple_values_as_file() {
        assert!(load_json_input("42").unwrap().is_number());
        assert!(load_json_input(r#""hello""#).unwrap().is_string());
        assert!(load_json_input("true").unwrap().is_boolean());
        assert!(load_json_input("null").unwrap().is_null());
    }

    #[test]
    fn test_load_json_input_objects_and_arrays() {
        assert!(load_json_input("{}").unwrap().is_object());
        assert!(load_json_input("[]").unwrap().is_array());
        assert!(load_json_input(r#"{"name": "test"}"#).unwrap().is_object());
        assert!(load_json_input(r#"[1, 2, 3]"#).unwrap().is_array());
    }

    #[test]
    fn test_load_config_default() {
        let config = LoadConfig::default();
        assert_eq!(config.max_file_size, DEFAULT_MAX_FILE_SIZE);
        assert_eq!(config.max_json_depth, DEFAULT_MAX_JSON_DEPTH);
    }

    #[test]
    fn test_load_config_with_limits() {
        let config = LoadConfig::with_limits(500, 100);
        assert_eq!(config.max_file_size, 500);
        assert_eq!(config.max_json_depth, 100);
    }

    #[test]
    fn test_load_config_from_env_no_env_vars() {
        let _env = EnvVarsGuard::acquire();
        let config = LoadConfig::from_env();
        assert_eq!(config.max_file_size, DEFAULT_MAX_FILE_SIZE);
        assert_eq!(config.max_json_depth, DEFAULT_MAX_JSON_DEPTH);
    }

    #[test]
    fn test_load_config_from_env_with_vars() {
        let _env = EnvVarsGuard::acquire();
        std::env::set_var("RJD_MAX_FILE_SIZE", "200000000");
        std::env::set_var("RJD_MAX_JSON_DEPTH", "500");
        let config = LoadConfig::from_env();
        assert_eq!(config.max_file_size, 200000000);
        assert_eq!(config.max_json_depth, 500);
    }

    #[test]
    fn test_load_config_from_env_with_invalid_vars() {
        let _env = EnvVarsGuard::acquire();
        std::env::set_var("RJD_MAX_FILE_SIZE", "invalid");
        std::env::set_var("RJD_MAX_JSON_DEPTH", "not_a_number");
        let config = LoadConfig::from_env();
        assert_eq!(config.max_file_size, DEFAULT_MAX_FILE_SIZE);
        assert_eq!(config.max_json_depth, DEFAULT_MAX_JSON_DEPTH);
    }

    #[test]
    fn test_reject_file_over_limit() {
        let (_temp, file_path) = temp_json_file(r#"{"test": "data"}"#);
        let config = LoadConfig::with_limits(5, 1000);
        let result = load_json_file_with_config(&file_path, &config);
        assert!(result.is_err());
        match result {
            Err(RjdError::FileTooLarge { size, limit, .. }) => {
                assert_eq!(limit, 5);
                assert!(size > 5);
            }
            _ => panic!("Expected FileTooLarge error"),
        }
    }

    #[test]
    fn test_accept_file_under_limit() {
        let (_temp, file_path) = temp_json_file(r#"{"test": "data"}"#);
        let config = LoadConfig::with_limits(1000, 1000);
        let result = load_json_file_with_config(&file_path, &config);
        assert!(result.is_ok());
        let value = result.unwrap();
        assert_eq!(value["test"], "data");
    }

    #[test]
    fn test_file_metadata_unavailable() {
        let file_path = PathBuf::from("/proc/some_nonexistent_file");
        let config = LoadConfig::default();
        let result = load_json_file_with_config(&file_path, &config);
        assert!(result.is_err());
        match result {
            Err(RjdError::FileRead { .. }) => (),
            _ => panic!("Expected FileRead error for unavailable metadata"),
        }
    }

    #[test]
    fn test_reject_json_over_depth_limit() {
        let nested = r#"{"a": {"b": {"c": {"d": {"e": "value"}}}}}"#;
        let (_temp, file_path) = temp_json_file(nested);
        let config = LoadConfig::with_limits(1000, 3);
        let result = load_json_file_with_config(&file_path, &config);
        assert!(result.is_err());
        match result {
            Err(RjdError::JsonParse { .. }) => {}
            _ => panic!("Expected JsonParse error for depth exceeded"),
        }
    }

    #[test]
    fn test_accept_json_under_depth_limit() {
        let nested = r#"{"a": {"b": {"c": {"d": {"e": "value"}}}}}"#;
        let (_temp, file_path) = temp_json_file(nested);
        let config = LoadConfig::with_limits(1000, 10);
        let result = load_json_file_with_config(&file_path, &config);
        assert!(result.is_ok());
        let value = result.unwrap();
        assert_eq!(value["a"]["b"]["c"]["d"]["e"], "value");
    }

    #[test]
    fn test_depth_limit_with_nested_arrays() {
        let nested = r#"[[[[[[42]]]]]]"#;
        let (_temp, file_path) = temp_json_file(nested);
        let config = LoadConfig::with_limits(1000, 5);
        let result = load_json_file_with_config(&file_path, &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_depth_limit_with_nested_objects() {
        let nested = r#"{"l1":{"l2":{"l3":{"l4":{"l5":"deep"}}}}}"#;
        let (_temp, file_path) = temp_json_file(nested);
        let config = LoadConfig::with_limits(1000, 10);
        let result = load_json_file_with_config(&file_path, &config);
        assert!(result.is_ok());
        let value = result.unwrap();
        assert_eq!(value["l1"]["l2"]["l3"]["l4"]["l5"], "deep");
    }

    #[test]
    fn test_symlink_policy_enum() {
        let reject = SymlinkPolicy::Reject;
        let follow = SymlinkPolicy::Follow;
        assert_eq!(reject, SymlinkPolicy::Reject);
        assert_eq!(follow, SymlinkPolicy::Follow);
        assert_ne!(reject, follow);
    }

    #[test]
    fn test_reject_symlink_by_default() {
        let (_temp, file_path) = temp_json_file(r#"{"test": "data"}"#);
        let symlink_dir = tempfile::tempdir().unwrap();
        let symlink_path = symlink_dir.path().join("symlink.json");
        #[cfg(unix)]
        std::os::unix::fs::symlink(&file_path, &symlink_path).unwrap();
        #[cfg(windows)]
        std::os::windows::fs::symlink_file(&file_path, &symlink_path).unwrap();
        let config = LoadConfig::default();
        let result = load_json_file_with_config(&symlink_path, &config);
        assert!(result.is_err());
        match result {
            Err(RjdError::SymlinkRejected { .. }) => {}
            _ => panic!("Expected SymlinkRejected error"),
        }
    }

    #[test]
    fn test_follow_symlink_with_policy() {
        let (_temp, file_path) = temp_json_file(r#"{"test": "data"}"#);
        let symlink_dir = tempfile::tempdir().unwrap();
        let symlink_path = symlink_dir.path().join("symlink.json");
        #[cfg(unix)]
        std::os::unix::fs::symlink(&file_path, &symlink_path).unwrap();
        #[cfg(windows)]
        std::os::windows::fs::symlink_file(&file_path, &symlink_path).unwrap();
        let config = LoadConfig::default();
        let result =
            load_json_file_with_config_and_policy(&symlink_path, &config, SymlinkPolicy::Follow);
        assert!(result.is_ok());
        let value = result.unwrap();
        assert_eq!(value["test"], "data");
    }

    #[test]
    fn test_symlink_to_nonexistent_target() {
        let symlink_dir = tempfile::tempdir().unwrap();
        let symlink_path = symlink_dir.path().join("broken_symlink.json");
        let nonexistent_path = PathBuf::from("/nonexistent/file.json");
        #[cfg(unix)]
        std::os::unix::fs::symlink(&nonexistent_path, &symlink_path).unwrap();
        #[cfg(windows)]
        std::os::windows::fs::symlink_file(&nonexistent_path, &symlink_path).unwrap();
        let config = LoadConfig::default();
        let result =
            load_json_file_with_config_and_policy(&symlink_path, &config, SymlinkPolicy::Follow);
        assert!(result.is_err());
    }

    #[test]
    #[cfg(unix)]
    fn test_circular_symlink_detection() {
        let temp_dir = tempfile::tempdir().unwrap();
        let link1 = temp_dir.path().join("link1");
        let link2 = temp_dir.path().join("link2");
        std::os::unix::fs::symlink(&link2, &link1).unwrap();
        std::os::unix::fs::symlink(&link1, &link2).unwrap();
        let config = LoadConfig::default();
        let result = load_json_file_with_config_and_policy(&link1, &config, SymlinkPolicy::Follow);
        assert!(result.is_err() || link1.canonicalize().is_err());
    }
}