use std::sync::{ Arc, Mutex };
use std::time::Duration;
use api_gemini::
{
client ::Client,
};
use tracing::Level;
use tracing_subscriber::
{
fmt ::{ self, format::FmtSpan },
Registry,
EnvFilter,
};
use core::cell::RefCell;
use tracing_subscriber::layer::Layer;
use tracing::{ Event, Subscriber, Instrument };
/// A single log record collected during a test.
///
/// Records are produced by the capture layer (or by `simulate_http_log`) and
/// read back by the tests to check which messages and structured fields the
/// client emitted.
#[ derive( Debug, Clone ) ]
pub struct LogEntry
{
  /// Level the event was emitted at.
  pub level : Level,
  /// Text of the event's `message` field; empty when no message was recorded.
  pub message : String,
  /// Target string taken from the event metadata.
  pub target : String,
  /// Structured fields rendered to strings, keyed by field name.
  pub fields : std::collections::HashMap< String, String >,
  /// Wall-clock time at which the record was created.
  pub timestamp : std::time::SystemTime,
}
/// Stateless `tracing_subscriber` layer used by the test logging setup to
/// copy emitted events into the per-thread capture buffer.
#[ derive( Debug ) ]
// Presumably allowed because its only user (the logging setup) is reached
// solely from feature-gated tests.
#[ allow( dead_code ) ]
struct CaptureLayer;

impl CaptureLayer
{
  /// Builds the capture layer; it carries no state of its own.
  #[ allow( dead_code ) ]
  fn new() -> Self
  {
    CaptureLayer
  }
}
impl< S > Layer< S > for CaptureLayer
where
  S : Subscriber,
{
  /// Converts `event` into a [`LogEntry`] and appends it to the thread-local
  /// `TEST_CAPTURE` buffer.
  ///
  /// Text recorded for the `message` field becomes [`LogEntry::message`]; the
  /// other visited fields are stored stringified in [`LogEntry::fields`].
  fn on_event( &self, event : &Event< '_ >, _ctx : tracing_subscriber::layer::Context< '_, S > )
  {
    let mut fields = std::collections::HashMap::new();
    let mut message = String::new();
    // The visitor's mutable borrows end with this statement, so both buffers
    // can be moved into the entry below without cloning.
    event.record( &mut FieldVisitor { fields : &mut fields, message : &mut message } );
    let metadata = event.metadata();
    let entry = LogEntry
    {
      level : *metadata.level(),
      message,
      target : metadata.target().to_string(),
      fields,
      timestamp : std::time::SystemTime::now(),
    };
    // NOTE(review): the buffer is thread-local, so events emitted on another
    // thread end up in that thread's buffer, not the test's.
    TEST_CAPTURE.with( |logs| logs.borrow_mut().push( entry ) );
  }
}
/// Field visitor that sends the `message` field into a string buffer and
/// every other field into a name → rendered-value map.
#[ allow( dead_code ) ]
struct FieldVisitor< 'a >
{
  /// Destination for fields other than `message`.
  fields : &'a mut std::collections::HashMap< String, String >,
  /// Destination for the `message` field's text.
  message : &'a mut String,
}

// Mirrors the allowance on the struct: the visitor is only reached from the
// capture layer, which may be unused when the logging tests are compiled out.
#[ allow( dead_code ) ]
impl FieldVisitor< '_ >
{
  /// Stores an already-rendered value in the field map under the field's name.
  fn put( &mut self, field : &tracing::field::Field, rendered : String )
  {
    self.fields.insert( field.name().to_owned(), rendered );
  }

  /// Writes `rendered` to the message buffer when `field` is `message`,
  /// otherwise stores it in the field map.
  fn route( &mut self, field : &tracing::field::Field, rendered : String )
  {
    match field.name()
    {
      "message" => *self.message = rendered,
      _ => self.put( field, rendered ),
    }
  }
}

impl tracing::field::Visit for FieldVisitor< '_ >
{
  fn record_debug( &mut self, field : &tracing::field::Field, value : &dyn core::fmt::Debug )
  {
    self.route( field, format!( "{value:?}" ) );
  }

  fn record_str( &mut self, field : &tracing::field::Field, value : &str )
  {
    self.route( field, value.to_owned() );
  }

  // Numeric values always land in the field map, even for a field named `message`.
  fn record_f64( &mut self, field : &tracing::field::Field, value : f64 )
  {
    self.put( field, value.to_string() );
  }

  fn record_u64( &mut self, field : &tracing::field::Field, value : u64 )
  {
    self.put( field, value.to_string() );
  }

  fn record_i64( &mut self, field : &tracing::field::Field, value : i64 )
  {
    self.put( field, value.to_string() );
  }
}
/// Handle to a shared, mutex-protected list of captured log entries.
///
/// NOTE(review): in the visible code the capture layer writes to the
/// thread-local `TEST_CAPTURE` buffer, not to `entries`; confirm whether this
/// type is still used anywhere.
#[ derive( Debug ) ]
pub struct TestLogCapture
{
  /// Entry list shared with the second handle returned by [`TestLogCapture::new`].
  pub entries : Arc< Mutex< Vec< LogEntry > > >,
}

impl TestLogCapture
{
  /// Creates an empty capture plus a second handle to the same entry list.
  ///
  /// Returns `(capture, entries)`, where `entries` is an `Arc` clone of
  /// `capture.entries`.
  #[ must_use ]
  pub fn new() -> ( Self, Arc< Mutex< Vec< LogEntry > > > )
  {
    let entries = Arc::new( Mutex::new( Vec::new() ) );
    ( Self { entries : Arc::clone( &entries ) }, entries )
  }

  /// Removes every captured entry.
  ///
  /// A poisoned lock (another thread panicked while holding it) is recovered
  /// instead of propagated. Emptying the list is valid whatever state the
  /// panicking thread left behind, and a failed test should not make cleanup
  /// panic too.
  pub fn clear( &self )
  {
    self.entries
    .lock()
    .unwrap_or_else( std::sync::PoisonError::into_inner )
    .clear();
  }
}
/// Sets `GEMINI_ENABLE_HTTP_LOGGING=1` in the process environment and then
/// builds a [`Client`]. The variable is presumably read by the client to turn
/// on HTTP logging; confirm against the client implementation.
///
/// # Panics
/// Panics if `Client::new` returns an error.
// NOTE(review): every logging test calls this, so the environment is mutated
// while other tests may be running concurrently.
#[ allow( dead_code ) ]
fn create_logging_client() -> Client
{
  std::env::set_var( "GEMINI_ENABLE_HTTP_LOGGING", "1" );
  let built = Client::new();
  built.expect( "Failed to create client for logging tests" )
}
/// Lists models with logging enabled and checks for a structured
/// request-start log (`url`, `method`, `request_id`) and a structured
/// success log (`duration_ms`, `status_code`, `response_size_bytes`).
#[ tokio::test ]
#[ cfg( feature = "logging" ) ]
async fn test_http_request_logging_basic()
{
  let _guard = setup_test_logging();
  let client = create_logging_client();
  let models_handle = client.models();
  let models = models_handle
  .list()
  .await
  .unwrap_or_else( |e| panic!( "HTTP request failed : {e}" ) );
  assert!( !models.models.is_empty() );

  let captured = get_captured_logs();

  let has_start_log = captured.iter().any( |record|
    record.message.contains( "Starting HTTP request" )
    && [ "url", "method", "request_id" ].iter().all( |key| record.fields.contains_key( *key ) )
  );
  assert!( has_start_log, "Missing structured request start log" );

  let has_success_log = captured.iter().any( |record|
    record.message.contains( "HTTP request completed successfully" )
    && [ "duration_ms", "status_code", "response_size_bytes" ].iter().all( |key| record.fields.contains_key( *key ) )
  );
  assert!( has_success_log, "Missing structured success log" );
}
/// Requests a model that does not exist and checks two things: a structured
/// ERROR entry carrying `error_type`, `error_message`, `url` and
/// `duration_ms` was captured, and its `error_type` mentions `ApiError`.
#[ tokio::test ]
#[ cfg( feature = "logging" ) ]
async fn test_error_logging_structured()
{
  let _guard = setup_test_logging();
  let client = create_logging_client();
  let models_api = client.models();
  let result = models_api.get( "models/non-existent-model" ).await;
  assert!( result.is_err(), "Expected error for non-existent model" );

  let logs = get_captured_logs();
  let error_entry = logs
  .iter()
  .find( |entry|
    entry.level == Level::ERROR
    && entry.fields.contains_key( "error_type" )
    && entry.fields.contains_key( "error_message" )
    && entry.fields.contains_key( "url" )
    && entry.fields.contains_key( "duration_ms" )
  )
  .unwrap_or_else( || panic!( "Missing structured error log : {logs:?}" ) );

  // `find` above guarantees the key is present, so indexing cannot panic.
  let error_type = &error_entry.fields[ "error_type" ];
  assert!( error_type.contains( "ApiError" ), "Unexpected error_type : {error_type}" );
}
/// Embeds a short text and checks that performance logs (entries with both
/// `duration_ms` and `operation` fields) were captured and that each logged
/// duration parses as a number in `[0, 30000)`.
#[ tokio::test ]
#[ cfg( feature = "logging" ) ]
async fn test_performance_monitoring_logging()
{
  let _guard = setup_test_logging();
  let client = create_logging_client();
  let models_handle = client.models();
  let model = models_handle.by_name( "text-embedding-004" );
  let embedding = model
  .embed_text( "Performance monitoring test" )
  .await
  .unwrap_or_else( |e| panic!( "Embed text failed : {e}" ) );
  assert!( !embedding.is_empty() );

  let captured = get_captured_logs();
  let perf_logs : Vec< _ > = captured
  .iter()
  .filter( |record| record.fields.contains_key( "duration_ms" ) && record.fields.contains_key( "operation" ) )
  .collect();
  assert!( !perf_logs.is_empty(), "Missing performance monitoring logs" );

  for record in &perf_logs
  {
    // The filter above guarantees `duration_ms` is present.
    let duration = record.fields[ "duration_ms" ].parse::< f64 >().unwrap();
    assert!( (0.0..30000.0).contains(&duration), "Invalid duration : {duration}" );
  }
}
/// Installs capture at INFO for `api_gemini`, lists models, and checks that
/// INFO or ERROR entries were captured while no `api_gemini` DEBUG entries were.
///
/// The DEBUG check is scoped to `api_gemini` targets because the setup only
/// adds an `api_gemini=<level>` directive on top of `RUST_LOG`. Other crates
/// may still emit DEBUG events whenever `RUST_LOG` enables them, and those
/// say nothing about this crate's filtering.
#[ tokio::test ]
#[ cfg( feature = "logging" ) ]
async fn test_log_level_filtering()
{
  let _guard = setup_test_logging_with_level( Level::INFO );
  let client = create_logging_client();
  let models_api = client.models();
  // The request outcome is irrelevant here; only the emitted logs are inspected.
  let _ = models_api.list().await;
  let logs = get_captured_logs();
  let has_info = logs.iter().any( |entry| entry.level == Level::INFO );
  let has_error = logs.iter().any( |entry| entry.level == Level::ERROR );
  let has_crate_debug = logs.iter().any( |entry|
    entry.level == Level::DEBUG && entry.target.starts_with( "api_gemini" )
  );
  assert!( has_info || has_error, "Missing INFO/ERROR level logs" );
  assert!( !has_crate_debug, "DEBUG logs should be filtered out at INFO level" );
}
/// Generates text with logging enabled and checks that request-related logs
/// were captured.
///
/// Best-effort: a generation error is printed and the test returns without
/// failing.
#[ tokio::test ]
#[ cfg( all( feature = "logging", feature = "streaming" ) ) ]
async fn test_streaming_logging()
{
  let _guard = setup_test_logging();
  let client = create_logging_client();
  let models_handle = client.models();
  let model = models_handle.by_name( "gemini-1.5-pro" );
  let text = match model.generate_text( "Count from 1 to 3" ).await
  {
    Ok( text ) => text,
    Err( e ) =>
    {
      println!( "Streaming not implemented yet : {e}" );
      return;
    }
  };
  assert!( !text.is_empty() );

  let captured = get_captured_logs();
  let saw_request_log = captured.iter().any( |record|
    record.fields.contains_key( "operation" )
    || record.fields.contains_key( "request_id" )
    || record.message.contains( "HTTP request" )
  );
  assert!( saw_request_log, "Missing HTTP request logs" );
}
/// Batch-embeds three texts and checks that batch-related logs were captured
/// and that every captured `batch_id` is the same.
///
/// Logs are considered batch-related when they carry `batch_id` or
/// `batch_size`, or when their message contains "batch". Only logs that
/// actually carry a `batch_id` take part in the consistency check. A log
/// matched by `batch_size` or message text alone has no id to disagree with,
/// and the first matched log need not be the one that carries the id.
#[ tokio::test ]
#[ cfg( feature = "logging" ) ]
async fn test_batch_operations_logging()
{
  let _guard = setup_test_logging();
  let client = create_logging_client();
  let models_api = client.models();
  let model = models_api.by_name( "text-embedding-004" );
  let texts = vec![
    "Batch logging test 1",
    "Batch logging test 2",
    "Batch logging test 3",
  ];
  let result = model.batch_embed_texts( &texts ).await;
  let embeddings = result.unwrap_or_else( |e| panic!( "Batch embed failed : {e}" ) );
  assert_eq!( embeddings.len(), texts.len() );

  let logs = get_captured_logs();
  let batch_logs : Vec< _ > = logs.iter().filter( |entry|
    entry.fields.contains_key( "batch_id" ) ||
    entry.fields.contains_key( "batch_size" ) ||
    entry.message.contains( "batch" )
  ).collect();
  assert!( !batch_logs.is_empty(), "Missing batch operation logs" );

  let mut batch_ids = batch_logs.iter().filter_map( |log| log.fields.get( "batch_id" ) );
  if let Some( first_id ) = batch_ids.next()
  {
    let same_batch_id = batch_ids.all( |id| id == first_id );
    assert!( same_batch_id, "Batch correlation ID should be consistent" );
  }
}
/// Lists models with logging enabled and checks that no captured message or
/// field value contains the `AIza` prefix, which this test uses as the marker
/// of an API key.
#[ tokio::test ]
#[ cfg( feature = "logging" ) ]
async fn test_sensitive_data_redaction()
{
  let _guard = setup_test_logging();
  let client = create_logging_client();
  let models_handle = client.models();
  // The request outcome is irrelevant; only what was logged matters.
  let _ = models_handle.list().await;

  let key_marker = "AIza";
  for record in get_captured_logs()
  {
    assert!( !record.message.contains( key_marker ), "API key leaked in log message" );
    if let Some( value ) = record.fields.values().find( |value| value.contains( key_marker ) )
    {
      panic!( "API key leaked in log field : {value}" );
    }
  }
}
/// Runs an embedding call inside an `embedding_operation` span and checks that
/// request-related logs were captured while the span was active.
///
/// NOTE(review): the capture layer ignores span context, so this only checks
/// that request logs exist, not that span fields were propagated onto them.
#[ tokio::test ]
#[ cfg( feature = "logging" ) ]
async fn test_span_context_propagation()
{
  let _guard = setup_test_logging();
  let operation_span = tracing::info_span!(
    "embedding_operation",
    operation_id = "test-123",
    user_context = "integration_test"
  );
  let embed_inside_span = async
  {
    let client = create_logging_client();
    let models_handle = client.models();
    let model = models_handle.by_name( "text-embedding-004" );
    model.embed_text( "Context propagation test" ).await
  };
  let embedding = embed_inside_span
  .instrument( operation_span )
  .await
  .unwrap_or_else( |e| panic!( "Context propagation test failed : {e}" ) );
  assert!( !embedding.is_empty() );

  let captured = get_captured_logs();
  let saw_request_log = captured.iter().any( |record|
    record.fields.contains_key( "operation" )
    || record.fields.contains_key( "request_id" )
    || record.message.contains( "HTTP request" )
  );
  assert!( saw_request_log, "Missing HTTP request logs for operation" );
}
/// Issues ten sequential list requests, 10 ms apart, and checks that all
/// succeeded. It then checks that between 1 and 35 captured entries mention
/// "HTTP request".
#[ tokio::test ]
#[ cfg( feature = "logging" ) ]
async fn test_log_sampling()
{
  let _guard = setup_test_logging();
  let client = create_logging_client();
  let models_handle = client.models();

  let mut results = Vec::with_capacity( 10 );
  for _ in 0..10
  {
    results.push( models_handle.list().await );
    tokio::time::sleep( Duration::from_millis( 10 ) ).await;
  }
  for result in results
  {
    assert!( result.is_ok() );
  }

  let captured = get_captured_logs();
  let request_log_count = captured
  .iter()
  .filter( |record| record.message.contains( "HTTP request" ) )
  .count();
  assert!( request_log_count > 0, "Should have some request logs" );
  assert!( request_log_count <= 35, "Too many logs - sampling may be needed : found {request_log_count}" );
}
// Per-thread buffer that `CaptureLayer::on_event` and `simulate_http_log`
// append to, and that the setup/read helpers clear and snapshot.
// Because it is thread-local, a test only sees entries recorded on its own
// thread. This presumably relies on `#[tokio::test]` running the test on a
// current-thread runtime. NOTE(review): confirm if multi-threaded runtimes
// are ever used here.
thread_local! {
static TEST_CAPTURE: RefCell< Vec< LogEntry > > = const { RefCell::new( Vec::new() ) };
}
/// Installs the capture subscriber with `api_gemini` at DEBUG verbosity.
///
/// Returns the `DefaultGuard` produced by `tracing::subscriber::set_default`;
/// keep it alive for the duration of the test.
#[ allow( dead_code ) ]
fn setup_test_logging() -> tracing::subscriber::DefaultGuard
{
  let verbosity = Level::DEBUG;
  setup_test_logging_with_level( verbosity )
}
/// Clears this thread's capture buffer, then installs a registry subscriber
/// as the thread's default. The subscriber is composed of:
/// - an `EnvFilter` built from `RUST_LOG` plus an `api_gemini=<level>` directive,
/// - the capture layer,
/// - a test-writer `fmt` layer that prints targets and span-close events.
///
/// # Panics
/// Panics if the generated `api_gemini=<level>` directive fails to parse.
#[ allow( dead_code ) ]
fn setup_test_logging_with_level( level: Level ) -> tracing::subscriber::DefaultGuard
{
  use tracing_subscriber::layer::SubscriberExt;

  TEST_CAPTURE.with( |buffer| buffer.borrow_mut().clear() );

  let filter = EnvFilter::from_default_env()
  .add_directive( format!( "api_gemini={level}" ).parse().unwrap() );

  let printer = fmt::layer()
  .with_test_writer()
  .with_target( true )
  .with_span_events( FmtSpan::CLOSE );

  // Layer order (filter, capture, printer) is kept as originally composed.
  let composed = Registry::default()
  .with( filter )
  .with( CaptureLayer::new() )
  .with( printer );

  tracing::subscriber::set_default( composed )
}
/// Returns a copy of every entry captured on the current thread so far.
#[ allow( dead_code ) ]
fn get_captured_logs() -> Vec< LogEntry >
{
  TEST_CAPTURE.with( |buffer| buffer.borrow().to_vec() )
}
/// Pushes a synthetic entry into this thread's capture buffer, as if an event
/// had been emitted with target `api_gemini::internal::http`.
#[ allow( dead_code ) ]
fn simulate_http_log( level: Level, message: &str, fields: std::collections::HashMap< String, String > )
{
  let target = String::from( "api_gemini::internal::http" );
  let record = LogEntry
  {
    level,
    message : message.to_owned(),
    target,
    fields,
    timestamp : std::time::SystemTime::now(),
  };
  TEST_CAPTURE.with( move |buffer| buffer.borrow_mut().push( record ) );
}