use nostr_sdk::prelude::*;
use std::sync::Arc;
use crate::core::constants::*;
use crate::core::error::{Error, Result};
use crate::core::serializers;
use crate::core::types::{EncryptionMode, JsonRpcMessage};
use crate::core::validation;
use crate::encryption;
use crate::relay::RelayPoolTrait;
/// `tracing` target attached to every log event emitted from this module.
const LOG_TARGET: &str = "contextvm_sdk::transport::base";
/// State shared by the transports: a relay pool handle, the configured
/// encryption policy, and a flag recording whether `connect` has succeeded.
pub struct BaseTransport {
/// Relay pool used to connect, subscribe, sign and publish events.
pub relay_pool: Arc<dyn RelayPoolTrait>,
/// Encryption policy consulted by `should_encrypt` for outgoing messages.
pub encryption_mode: EncryptionMode,
/// Set by `connect` and cleared by `disconnect`; both are no-ops when the
/// flag already reflects the requested state.
pub is_connected: bool,
}
impl BaseTransport {
pub async fn connect(&mut self, relay_urls: &[String]) -> Result<()> {
if self.is_connected {
return Ok(());
}
self.relay_pool.connect(relay_urls).await?;
self.is_connected = true;
Ok(())
}
pub async fn disconnect(&mut self) -> Result<()> {
if !self.is_connected {
return Ok(());
}
self.relay_pool.disconnect().await?;
self.is_connected = false;
Ok(())
}
pub async fn get_public_key(&self) -> Result<PublicKey> {
self.relay_pool.public_key().await
}
pub async fn subscribe_for_pubkey(&self, pubkey: &PublicKey) -> Result<()> {
let p_tag = pubkey.to_hex();
let now = Timestamp::now();
let ephemeral_filter = Filter::new()
.kind(Kind::Custom(CTXVM_MESSAGES_KIND))
.custom_tag(SingleLetterTag::lowercase(Alphabet::P), p_tag.clone())
.since(now);
let gift_wrap_filter = Filter::new()
.kind(Kind::Custom(GIFT_WRAP_KIND))
.custom_tag(SingleLetterTag::lowercase(Alphabet::P), p_tag.clone())
.since(now);
let ephemeral_gift_wrap_filter = Filter::new()
.kind(Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND))
.custom_tag(SingleLetterTag::lowercase(Alphabet::P), p_tag)
.since(now);
self.relay_pool
.subscribe(vec![
ephemeral_filter,
gift_wrap_filter,
ephemeral_gift_wrap_filter,
])
.await
}
pub fn convert_event_to_mcp(&self, content: &str) -> Option<JsonRpcMessage> {
validation::validate_and_parse(content)
}
pub async fn create_signed_event(
&self,
message: &JsonRpcMessage,
kind: u16,
tags: Vec<Tag>,
) -> Result<Event> {
let builder = serializers::mcp_to_nostr_event(message, kind, tags)?;
self.relay_pool.sign(builder).await
}
pub async fn prepare_mcp_message(
&self,
message: &JsonRpcMessage,
recipient: &PublicKey,
kind: u16,
tags: Vec<Tag>,
is_encrypted: Option<bool>,
gift_wrap_kind: Option<u16>,
) -> Result<(EventId, Event)> {
let should_encrypt = self.should_encrypt(kind, is_encrypted);
let event = self.create_signed_event(message, kind, tags).await?;
let signed_event_id = event.id;
if should_encrypt {
let event_json =
serde_json::to_string(&event).map_err(|e| Error::Encryption(e.to_string()))?;
let signer = self
.relay_pool
.signer()
.await
.map_err(|e| Error::Encryption(e.to_string()))?;
let selected_gift_wrap_kind = gift_wrap_kind.unwrap_or(GIFT_WRAP_KIND);
let gift_wrap_event = encryption::gift_wrap_single_layer_with_kind(
&signer,
recipient,
&event_json,
selected_gift_wrap_kind,
)
.await?;
tracing::debug!(
target: LOG_TARGET,
signed_event_id = %signed_event_id,
envelope_id = %gift_wrap_event.id,
gift_wrap_kind = selected_gift_wrap_kind,
"Prepared encrypted MCP message"
);
Ok((signed_event_id, gift_wrap_event))
} else {
tracing::debug!(
target: LOG_TARGET,
signed_event_id = %signed_event_id,
"Prepared unencrypted MCP message"
);
Ok((signed_event_id, event))
}
}
pub async fn send_mcp_message(
&self,
message: &JsonRpcMessage,
recipient: &PublicKey,
kind: u16,
tags: Vec<Tag>,
is_encrypted: Option<bool>,
gift_wrap_kind: Option<u16>,
) -> Result<EventId> {
let should_encrypt = self.should_encrypt(kind, is_encrypted);
let event = self.create_signed_event(message, kind, tags).await?;
let signed_event_id = event.id;
if should_encrypt {
let event_json =
serde_json::to_string(&event).map_err(|e| Error::Encryption(e.to_string()))?;
let signer = self
.relay_pool
.signer()
.await
.map_err(|e| Error::Encryption(e.to_string()))?;
let selected_gift_wrap_kind = gift_wrap_kind.unwrap_or(GIFT_WRAP_KIND);
let gift_wrap_event = encryption::gift_wrap_single_layer_with_kind(
&signer,
recipient,
&event_json,
selected_gift_wrap_kind,
)
.await?;
self.relay_pool.publish_event(&gift_wrap_event).await?;
tracing::debug!(
target: LOG_TARGET,
signed_event_id = %signed_event_id,
envelope_id = %gift_wrap_event.id,
gift_wrap_kind = selected_gift_wrap_kind,
"Sent encrypted MCP message"
);
} else {
self.relay_pool.publish_event(&event).await?;
tracing::debug!(
target: LOG_TARGET,
signed_event_id = %signed_event_id,
"Sent unencrypted MCP message"
);
}
Ok(signed_event_id)
}
pub fn should_encrypt(&self, kind: u16, is_encrypted: Option<bool>) -> bool {
if UNENCRYPTED_KINDS.contains(&kind) {
return false;
}
match self.encryption_mode {
EncryptionMode::Disabled => false,
EncryptionMode::Required => true,
EncryptionMode::Optional => is_encrypted.unwrap_or(true),
}
}
pub fn create_recipient_tags(pubkey: &PublicKey) -> Vec<Tag> {
vec![Tag::public_key(*pubkey)]
}
pub fn create_response_tags(pubkey: &PublicKey, event_id: &EventId) -> Vec<Tag> {
vec![Tag::public_key(*pubkey), Tag::event(*event_id)]
}
pub fn compose_outbound_tags(
base_tags: &[Tag],
discovery_tags: &[Tag],
negotiation_tags: &[Tag],
) -> Vec<Tag> {
let mut tags =
Vec::with_capacity(base_tags.len() + discovery_tags.len() + negotiation_tags.len());
tags.extend_from_slice(base_tags);
tags.extend_from_slice(discovery_tags);
tags.extend_from_slice(negotiation_tags);
tags
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::types::*;

    /// Free-function copy of the decision made by `BaseTransport::should_encrypt`,
    /// so the policy can be exercised without building a transport.
    // NOTE(review): this duplicates the production logic rather than calling it —
    // keep the two in sync, or switch to the real method once a relay pool stub exists.
    fn should_encrypt(mode: EncryptionMode, kind: u16, is_encrypted: Option<bool>) -> bool {
        let exempt = UNENCRYPTED_KINDS.contains(&kind);
        !exempt
            && match mode {
                EncryptionMode::Disabled => false,
                EncryptionMode::Required => true,
                EncryptionMode::Optional => is_encrypted.unwrap_or(true),
            }
    }

    /// Collects the first field (the tag name) of every tag, in order.
    fn tag_names(tags: &[Tag]) -> Vec<String> {
        tags.iter()
            .map(|tag| tag.clone().to_vec()[0].clone())
            .collect()
    }

    /// Builds a tag that has only a custom name and no values.
    fn make_custom_tag(name: &str) -> Tag {
        let kind = TagKind::Custom(name.into());
        Tag::custom(kind, Vec::<String>::new())
    }

    // Disabled mode ignores the per-message preference entirely.
    #[test]
    fn test_should_encrypt_disabled_mode() {
        for preference in [None, Some(true), Some(false)] {
            assert!(!should_encrypt(
                EncryptionMode::Disabled,
                CTXVM_MESSAGES_KIND,
                preference
            ));
        }
    }

    // Required mode encrypts even when the caller asks not to.
    #[test]
    fn test_should_encrypt_required_mode() {
        for preference in [None, Some(false), Some(true)] {
            assert!(should_encrypt(
                EncryptionMode::Required,
                CTXVM_MESSAGES_KIND,
                preference
            ));
        }
    }

    // Optional mode follows the preference and defaults to encrypting.
    #[test]
    fn test_should_encrypt_optional_mode() {
        let cases = [(None, true), (Some(true), true), (Some(false), false)];
        for (preference, expected) in cases {
            assert_eq!(
                should_encrypt(EncryptionMode::Optional, CTXVM_MESSAGES_KIND, preference),
                expected
            );
        }
    }

    // Kinds in UNENCRYPTED_KINDS are never encrypted, whatever the mode.
    #[test]
    fn test_should_encrypt_announcement_kinds_never_encrypted() {
        for kind in UNENCRYPTED_KINDS.iter().copied() {
            for mode in [
                EncryptionMode::Required,
                EncryptionMode::Optional,
                EncryptionMode::Disabled,
            ] {
                assert!(!should_encrypt(mode, kind, Some(true)));
            }
        }
    }

    #[test]
    fn test_create_recipient_tags() {
        let pubkey = Keys::generate().public_key();
        let tags = BaseTransport::create_recipient_tags(&pubkey);
        assert_eq!(tags.len(), 1);
        let fields = tags[0].clone().to_vec();
        assert_eq!(fields[0], "p");
        assert_eq!(fields[1], pubkey.to_hex());
    }

    #[test]
    fn test_create_response_tags() {
        let pubkey = Keys::generate().public_key();
        let event_id =
            EventId::from_hex("0000000000000000000000000000000000000000000000000000000000000001")
                .unwrap();
        let tags = BaseTransport::create_response_tags(&pubkey, &event_id);
        assert_eq!(tags.len(), 2);

        let recipient_fields = tags[0].clone().to_vec();
        assert_eq!(recipient_fields[0], "p");
        assert_eq!(recipient_fields[1], pubkey.to_hex());

        let reference_fields = tags[1].clone().to_vec();
        assert_eq!(reference_fields[0], "e");
        assert_eq!(reference_fields[1], event_id.to_hex());
    }

    #[test]
    fn test_convert_event_to_mcp_valid_request() {
        let raw = r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#;
        let parsed = serde_json::from_str::<serde_json::Value>(raw).unwrap();
        let message = crate::core::validation::validate_message(&parsed).unwrap();
        assert!(message.is_request());
        assert_eq!(message.method(), Some("tools/list"));
    }

    #[test]
    fn test_convert_event_to_mcp_valid_notification() {
        let raw = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        let parsed = serde_json::from_str::<serde_json::Value>(raw).unwrap();
        let message = crate::core::validation::validate_message(&parsed).unwrap();
        assert!(message.is_notification());
    }

    #[test]
    fn test_convert_event_to_mcp_valid_response() {
        let raw = r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#;
        let parsed = serde_json::from_str::<serde_json::Value>(raw).unwrap();
        let message = crate::core::validation::validate_message(&parsed).unwrap();
        assert!(message.is_response());
    }

    #[test]
    fn test_convert_event_to_mcp_invalid_json() {
        let raw = "not json at all";
        assert!(serde_json::from_str::<serde_json::Value>(raw).is_err());
    }

    #[test]
    fn test_convert_event_to_mcp_invalid_jsonrpc_version() {
        let raw = r#"{"jsonrpc":"1.0","id":1,"method":"test"}"#;
        let parsed = serde_json::from_str::<serde_json::Value>(raw).unwrap();
        assert!(crate::core::validation::validate_message(&parsed).is_none());
    }

    #[test]
    fn test_convert_event_to_mcp_oversized_message() {
        // One byte over the limit must be rejected.
        let oversized = "x".repeat(MAX_MESSAGE_SIZE + 1);
        assert!(!crate::core::validation::validate_message_size(&oversized));
    }

    #[test]
    fn compose_outbound_tags_ordering() {
        let base = vec![Tag::public_key(Keys::generate().public_key())];
        let discovery = vec![make_custom_tag("support_encryption")];
        let negotiation = vec![make_custom_tag("pmi")];
        let combined = BaseTransport::compose_outbound_tags(&base, &discovery, &negotiation);
        assert_eq!(tag_names(&combined), ["p", "support_encryption", "pmi"]);
    }

    #[test]
    fn compose_outbound_tags_empty_discovery() {
        let base = vec![Tag::public_key(Keys::generate().public_key())];
        let negotiation = vec![make_custom_tag("pmi")];
        let combined = BaseTransport::compose_outbound_tags(&base, &[], &negotiation);
        assert_eq!(tag_names(&combined), ["p", "pmi"]);
    }

    #[test]
    fn compose_outbound_tags_all_empty() {
        let combined = BaseTransport::compose_outbound_tags(&[], &[], &[]);
        assert!(combined.is_empty());
    }

    #[test]
    fn compose_outbound_tags_preserves_all_elements() {
        let discovery = vec![
            make_custom_tag("support_encryption"),
            make_custom_tag("support_encryption_ephemeral"),
        ];
        let combined = BaseTransport::compose_outbound_tags(&[], &discovery, &[]);
        assert_eq!(
            tag_names(&combined),
            ["support_encryption", "support_encryption_ephemeral"]
        );
    }
}