use tari_common_sqlite::connection::DbConnection;
use tari_shutdown::Shutdown;
use tari_test_utils::unpack_enum;
use crate::{
CommsBuilder,
peer_manager::database::{MIGRATIONS, PeerDatabaseSql},
protocol::rpc::{
RpcError,
RpcServer,
RpcStatus,
RpcStatusCode,
test::mock::{MockRpcClient, MockRpcService},
},
test_utils::node_identity::build_node_identity,
transports::MemoryTransport,
};
/// End-to-end check that an RPC service registered on one comms node can be
/// called from a second node over the in-memory transport.
///
/// Node 1 hosts a `MockRpcService`. Node 2 dials node 1, opens a
/// `MockRpcClient` on that connection and issues two requests:
/// - one the mock answers successfully, which must increment the call count;
/// - one the mock answers with a `bad_request` status, which must reach the
///   client as `RpcError::RequestFailed` carrying `RpcStatusCode::BadRequest`.
#[tokio::test]
async fn run_service() {
    // --- Node 1: hosts the mock RPC service ---
    let node_identity1 = build_node_identity(Default::default());
    let rpc_service = MockRpcService::new();
    // `rpc_service` is moved into the RPC server below; this shared handle is
    // what the test uses afterwards to script responses and read call counts.
    let mock_state = rpc_service.shared_state();
    // Both nodes are given signals from this single `Shutdown`.
    let shutdown = Shutdown::new();
    let db_connection = DbConnection::connect_temp_file_and_migrate(MIGRATIONS).unwrap();
    let peers_db = PeerDatabaseSql::new(db_connection, &node_identity1.to_peer()).unwrap();
    let comms1 = CommsBuilder::new()
        .with_listener_address(node_identity1.first_public_address().unwrap())
        .with_node_identity(node_identity1)
        .with_shutdown_signal(shutdown.to_signal())
        .with_peer_storage(peers_db)
        .build()
        .unwrap()
        .add_rpc_server(RpcServer::new().add_service(rpc_service))
        .spawn_with_transport(MemoryTransport)
        .await
        .unwrap();

    // --- Node 2: the RPC client side ---
    let node_identity2 = build_node_identity(Default::default());
    let db_connection = DbConnection::connect_temp_file_and_migrate(MIGRATIONS).unwrap();
    let peers_db = PeerDatabaseSql::new(db_connection, &node_identity2.to_peer()).unwrap();
    // Builder calls are kept in the same order as for node 1. `node_identity2`
    // is not needed after this point, so it is moved in rather than cloned.
    let comms2 = CommsBuilder::new()
        .with_listener_address(node_identity2.first_public_address().unwrap())
        .with_node_identity(node_identity2)
        .with_shutdown_signal(shutdown.to_signal())
        .with_peer_storage(peers_db)
        .build()
        .unwrap();
    // Register node 1 in node 2's peer manager before spawning. Presumably
    // `dial_peer` needs this record to resolve node 1's address from its
    // node id — TODO confirm.
    comms2
        .peer_manager()
        .add_or_update_peer(comms1.node_identity().to_peer())
        .await
        .unwrap();
    let comms2 = comms2.spawn_with_transport(MemoryTransport).await.unwrap();

    // --- Connect and exercise the RPC service ---
    let mut conn = comms2
        .connectivity()
        .dial_peer(comms1.node_identity().node_id().clone())
        .await
        .unwrap();
    let mut client = conn.connect_rpc::<MockRpcClient>().await.unwrap();

    // Successful round trip: the service must have been invoked exactly once.
    mock_state.set_response_ok(&());
    client.request_response::<_, ()>((), 0.into()).await.unwrap();
    assert_eq!(mock_state.call_count(), 1);

    // Error round trip: a server-side status becomes `RpcError::RequestFailed`
    // on the client, and the status code survives the trip.
    mock_state.set_response_err(RpcStatus::bad_request("Insert 💾"));
    let err = client.request_response::<_, ()>((), 0.into()).await.unwrap_err();
    unpack_enum!(RpcError::RequestFailed(status) = err);
    unpack_enum!(RpcStatusCode::BadRequest = status.as_status_code());
    assert_eq!(mock_state.call_count(), 2);
}