#include "joiner.hpp"
#include <stdio.h>
#include "common/code_utils.hpp"
#include "common/debug.hpp"
#include "common/encoding.hpp"
#include "common/instance.hpp"
#include "common/locator-getters.hpp"
#include "common/logging.hpp"
#include "common/string.hpp"
#include "meshcop/meshcop.hpp"
#include "radio/radio.hpp"
#include "thread/thread_netif.hpp"
#include "thread/thread_uri_paths.hpp"
#include "utils/otns.hpp"
#if OPENTHREAD_CONFIG_JOINER_ENABLE
using ot::Encoding::BigEndian::HostSwap16;
namespace ot {
namespace MeshCoP {
// Constructs the Joiner in the IDLE state with no application callback, no
// prepared JOIN_FIN.req message, and an all-zero joiner-router candidate table.
// The joiner-entrust resource (mJoinerEntrust) is registered with the
// Coap::Coap agent (not Coap::CoapSecure) so JOIN_ENT.ntf can be received.
// NOTE: the initializer list must keep the member declaration order (header).
Joiner::Joiner(Instance &aInstance)
: InstanceLocator(aInstance)
, mState(OT_JOINER_STATE_IDLE)
, mCallback(NULL)
, mContext(NULL)
, mJoinerRouterIndex(0)
, mFinalizeMessage(NULL)
, mTimer(aInstance, &Joiner::HandleTimer, this)
, mJoinerEntrust(OT_URI_PATH_JOINER_ENTRUST, &Joiner::HandleJoinerEntrust, this)
{
// Zero the candidate table; an mPriority of zero is treated as an empty slot.
memset(mJoinerRouters, 0, sizeof(mJoinerRouters));
Get<Coap::Coap>().AddResource(mJoinerEntrust);
}
// Derives this device's Joiner ID into aJoinerId: the radio's IEEE EUI-64 is
// read into aJoinerId and then passed to ComputeJoinerId() as both input and
// output. NOTE(review): this in-place call assumes ComputeJoinerId() tolerates
// aliased arguments — confirm before changing the call shape.
void Joiner::GetJoinerId(Mac::ExtAddress &aJoinerId) const
{
Get<Radio>().GetIeeeEui64(aJoinerId);
ComputeJoinerId(aJoinerId, aJoinerId);
}
// Moves the joiner to aState through the Notifier (flag OT_CHANGED_JOINER_STATE)
// and logs the transition only when Notifier::Update() reports success
// (presumably it reports an error when the state did not actually change).
void Joiner::SetState(otJoinerState aState)
{
    const otJoinerState previousState = mState;

    // Only referenced by the log macro, which may compile away.
    OT_UNUSED_VARIABLE(previousState);

    if (Get<Notifier>().Update(mState, aState, OT_CHANGED_JOINER_STATE) == OT_ERROR_NONE)
    {
        otLogInfoMeshCoP("JoinerState: %s -> %s", JoinerStateToString(previousState), JoinerStateToString(aState));
    }
}
bool Joiner::IsPskdValid(const char *aPskd)
{
bool valid = false;
size_t pskdLength = StringLength(aPskd, kPskdMaxLength + 1);
OT_STATIC_ASSERT(static_cast<uint8_t>(kPskdMaxLength) <= static_cast<uint8_t>(Dtls::kPskMaxLength),
"The maximum length of DTLS PSK is smaller than joiner PSKd");
VerifyOrExit(pskdLength >= kPskdMinLength && pskdLength <= kPskdMaxLength, OT_NOOP);
for (size_t i = 0; i < pskdLength; i++)
{
char c = aPskd[i];
VerifyOrExit(isdigit(c) || isupper(c), OT_NOOP);
VerifyOrExit(c != 'I' && c != 'O' && c != 'Q' && c != 'Z', OT_NOOP);
}
valid = true;
exit:
return valid;
}
// Starts the joining process.
//
// Validates aPskd, switches to a random extended address, starts the secure
// CoAP agent on kJoinerUdpPort with the PSKd, builds the JOIN_FIN.req message,
// and launches an MLE discover scan whose results go to HandleDiscoverResult().
//
// Returns OT_ERROR_NONE on success (state becomes DISCOVER),
// OT_ERROR_BUSY if a join is already in progress, OT_ERROR_INVALID_ARGS for a
// bad PSKd, or the error from the failing CoapSecure/message/discover step.
// On any failure after the secure agent was started, it is stopped again so a
// later Start() begins from a clean state.
otError Joiner::Start(const char *     aPskd,
                      const char *     aProvisioningUrl,
                      const char *     aVendorName,
                      const char *     aVendorModel,
                      const char *     aVendorSwVersion,
                      const char *     aVendorData,
                      otJoinerCallback aCallback,
                      void *           aContext)
{
    otError         error;
    bool            coapSecureStarted = false;
    Mac::ExtAddress randomAddress;

    otLogInfoMeshCoP("Joiner starting");

    VerifyOrExit(mState == OT_JOINER_STATE_IDLE, error = OT_ERROR_BUSY);
    VerifyOrExit(IsPskdValid(aPskd), error = OT_ERROR_INVALID_ARGS);

    // Discovery runs with a random extended address; HandleDiscoverResult()
    // switches to the Joiner ID once the scan is complete.
    randomAddress.GenerateRandom();
    Get<Mac::Mac>().SetExtAddress(randomAddress);
    Get<Mle::MleRouter>().UpdateLinkLocalAddress();

    SuccessOrExit(error = Get<Coap::CoapSecure>().Start(kJoinerUdpPort));
    coapSecureStarted = true;

    // IsPskdValid() already capped the length at kPskdMaxLength.
    SuccessOrExit(error = Get<Coap::CoapSecure>().SetPsk(reinterpret_cast<const uint8_t *>(aPskd),
                                                         static_cast<uint8_t>(strlen(aPskd))));

    // A priority of zero marks an empty slot in the candidate table.
    for (JoinerRouter *router = &mJoinerRouters[0]; router < OT_ARRAY_END(mJoinerRouters); router++)
    {
        router->mPriority = 0;
    }

    SuccessOrExit(error = PrepareJoinerFinalizeMessage(aProvisioningUrl, aVendorName, aVendorModel, aVendorSwVersion,
                                                       aVendorData));

    SuccessOrExit(error = Get<Mle::MleRouter>().Discover(Mac::ChannelMask(0), Get<Mac::Mac>().GetPanId(),
                                                         true, true, HandleDiscoverResult, this));

    mCallback = aCallback;
    mContext  = aContext;

    SetState(OT_JOINER_STATE_DISCOVER);

exit:
    if (error != OT_ERROR_NONE)
    {
        otLogWarnMeshCoP("Failed to start joiner: %s", otThreadErrorToString(error));

        // The state is still IDLE here, so Finish() will never run for this
        // attempt; undo the secure agent start ourselves (only if we did it).
        if (coapSecureStarted)
        {
            Get<Coap::CoapSecure>().Stop();
        }

        FreeJoinerFinalizeMessage();
    }

    return error;
}
// Aborts any joining in progress without notifying the application: the
// callback is cleared before Finish(), which only invokes a non-NULL callback.
void Joiner::Stop(void)
{
    mCallback = NULL;

    otLogInfoMeshCoP("Joiner stopped");

    Finish(OT_ERROR_ABORT);
}
// Ends the current join attempt with result aError.
//
// Does nothing when already IDLE. Otherwise:
//  - from CONNECT or later, disconnects the secure session, closes the joiner
//    UDP port and stops the timer;
//  - in every non-idle state, stops the secure CoAP agent.
// Then returns to IDLE, releases any unsent finalize message (this must follow
// the state change, since FreeJoinerFinalizeMessage() only frees in IDLE), and
// reports aError through the application callback if one is registered.
void Joiner::Finish(otError aError)
{
    VerifyOrExit(mState != OT_JOINER_STATE_IDLE, OT_NOOP);

    if (mState != OT_JOINER_STATE_DISCOVER)
    {
        Get<Coap::CoapSecure>().Disconnect();
        IgnoreError(Get<Ip6::Filter>().RemoveUnsecurePort(kJoinerUdpPort));
        mTimer.Stop();
    }

    Get<Coap::CoapSecure>().Stop();

    SetState(OT_JOINER_STATE_IDLE);
    FreeJoinerFinalizeMessage();

    if (mCallback != NULL)
    {
        mCallback(aError, mContext);
    }

exit:
    return;
}
// Maps a discover result to a non-zero uint8_t priority (higher is better).
//
// The RSSI is clamped to [-127, -1] (OT_RADIO_RSSI_INVALID counts as -127),
// then offset by 256 when the steering data names specific joiners or by 128
// when it allows any joiner. This gives 129..255 for specific steering data
// and 1..127 for "allow any", so zero stays free to mean "empty slot".
uint8_t Joiner::CalculatePriority(int8_t aRssi, bool aSteeringDataAllowsAny)
{
    int16_t clampedRssi;

    if (aRssi == OT_RADIO_RSSI_INVALID || aRssi < -127)
    {
        clampedRssi = -127;
    }
    else if (aRssi > -1)
    {
        clampedRssi = -1;
    }
    else
    {
        clampedRssi = aRssi;
    }

    clampedRssi += aSteeringDataAllowsAny ? 128 : 256;

    return static_cast<uint8_t>(clampedRssi);
}
// C-style discover callback: aContext is the Joiner passed to Discover().
void Joiner::HandleDiscoverResult(otActiveScanResult *aResult, void *aContext)
{
    Joiner &joiner = *static_cast<Joiner *>(aContext);

    joiner.HandleDiscoverResult(aResult);
}
// Handles one discover callback while in the DISCOVER state (ignored otherwise).
// A non-NULL result is recorded as a joiner-router candidate. A NULL result is
// treated as the end of discovery: the device adopts its Joiner ID as extended
// address and starts connecting to the candidates from the first one.
void Joiner::HandleDiscoverResult(otActiveScanResult *aResult)
{
    VerifyOrExit(mState == OT_JOINER_STATE_DISCOVER, OT_NOOP);

    if (aResult == NULL)
    {
        Mac::ExtAddress joinerId;

        GetJoinerId(joinerId);
        Get<Mac::Mac>().SetExtAddress(joinerId);
        Get<Mle::MleRouter>().UpdateLinkLocalAddress();

        mJoinerRouterIndex = 0;
        TryNextJoinerRouter(OT_ERROR_NONE);
    }
    else
    {
        SaveDiscoveredJoinerRouter(*aResult);
    }

exit:
    return;
}
void Joiner::SaveDiscoveredJoinerRouter(const otActiveScanResult &aResult)
{
uint8_t priority;
bool doesAllowAny = true;
JoinerRouter *end = OT_ARRAY_END(mJoinerRouters);
JoinerRouter *entry;
for (uint8_t i = 0; i < aResult.mSteeringData.mLength; i++)
{
if (aResult.mSteeringData.m8[i] != 0xff)
{
doesAllowAny = false;
break;
}
}
otLogInfoMeshCoP("Joiner discover network: %s, pan:0x%04x, port:%d, chan:%d, rssi:%d, allow-any:%s",
static_cast<const Mac::ExtAddress &>(aResult.mExtAddress).ToString().AsCString(), aResult.mPanId,
aResult.mJoinerUdpPort, aResult.mChannel, aResult.mRssi, doesAllowAny ? "yes" : "no");
priority = CalculatePriority(aResult.mRssi, doesAllowAny);
for (entry = &mJoinerRouters[0]; entry < end; entry++)
{
if (priority > entry->mPriority)
{
break;
}
}
VerifyOrExit(entry < end, OT_NOOP);
memmove(entry + 1, entry,
static_cast<size_t>(reinterpret_cast<uint8_t *>(end - 1) - reinterpret_cast<uint8_t *>(entry)));
entry->mExtAddr = static_cast<const Mac::ExtAddress &>(aResult.mExtAddress);
entry->mPanId = aResult.mPanId;
entry->mJoinerUdpPort = aResult.mJoinerUdpPort;
entry->mChannel = aResult.mChannel;
entry->mPriority = priority;
exit:
return;
}
// Tries to connect to candidates starting at mJoinerRouterIndex and stops at the
// first empty (zero-priority) slot. The index always advances past the candidate
// just tried. Returns as soon as one Connect() succeeds.
//
// If no candidate connects, the attempt is finished with the first error seen:
// aPrevError if set, else the first Connect() error, else OT_ERROR_NOT_FOUND.
void Joiner::TryNextJoinerRouter(otError aPrevError)
{
    while (mJoinerRouterIndex < OT_ARRAY_LENGTH(mJoinerRouters))
    {
        JoinerRouter &candidate = mJoinerRouters[mJoinerRouterIndex];
        otError       connectError;

        if (candidate.mPriority == 0)
        {
            break;
        }

        connectError = Connect(candidate);
        mJoinerRouterIndex++;

        if (connectError == OT_ERROR_NONE)
        {
            ExitNow();
        }

        if (aPrevError == OT_ERROR_NONE)
        {
            aPrevError = connectError;
        }
    }

    Finish((aPrevError == OT_ERROR_NONE) ? OT_ERROR_NOT_FOUND : aPrevError);

exit:
    return;
}
// Starts a secure (DTLS) CoAP connection to the joiner router aRouter.
//
// Tunes the MAC to the router's PAN ID and channel, opens the joiner UDP port in
// the IPv6 filter, and connects to the router's link-local address on its joiner
// UDP port. Returns OT_ERROR_NONE and moves to CONNECT on success; otherwise
// returns the failing step's error. On failure the joiner UDP port is closed
// again. Otherwise a failure while still in DISCOVER would leave it open, because
// Finish() only removes the port for CONNECT and later states.
otError Joiner::Connect(JoinerRouter &aRouter)
{
    otError       error = OT_ERROR_NOT_FOUND;
    Ip6::SockAddr sockaddr;

    otLogInfoMeshCoP("Joiner connecting to %s, pan:0x%04x, chan:%d", aRouter.mExtAddr.ToString().AsCString(),
                     aRouter.mPanId, aRouter.mChannel);

    Get<Mac::Mac>().SetPanId(aRouter.mPanId);
    SuccessOrExit(error = Get<Mac::Mac>().SetPanChannel(aRouter.mChannel));
    SuccessOrExit(error = Get<Ip6::Filter>().AddUnsecurePort(kJoinerUdpPort));

    sockaddr.GetAddress().SetToLinkLocalAddress(aRouter.mExtAddr);
    sockaddr.mPort = aRouter.mJoinerUdpPort;

    SuccessOrExit(error = Get<Coap::CoapSecure>().Connect(sockaddr, Joiner::HandleSecureCoapClientConnect, this));

    SetState(OT_JOINER_STATE_CONNECT);

exit:
    if (error != OT_ERROR_NONE)
    {
        // Removing a port that was never added is harmless; the error is ignored.
        IgnoreError(Get<Ip6::Filter>().RemoveUnsecurePort(kJoinerUdpPort));
        otLogWarnMeshCoP("Failed to start secure joiner connection: %s", otThreadErrorToString(error));
    }

    return error;
}
void Joiner::HandleSecureCoapClientConnect(bool aConnected, void *aContext)
{
static_cast<Joiner *>(aContext)->HandleSecureCoapClientConnect(aConnected);
}
void Joiner::HandleSecureCoapClientConnect(bool aConnected)
{
VerifyOrExit(mState == OT_JOINER_STATE_CONNECT, OT_NOOP);
if (aConnected)
{
SetState(OT_JOINER_STATE_CONNECTED);
SendJoinerFinalize();
mTimer.Start(kReponseTimeout);
}
else
{
TryNextJoinerRouter(OT_ERROR_SECURITY);
}
exit:
return;
}
// Builds the confirmable CoAP POST JOIN_FIN.req into mFinalizeMessage.
//
// Payload TLVs, in order: State (Accept), Vendor Name, Vendor Model,
// Vendor SW Version, Vendor Stack Version (from OPENTHREAD_CONFIG_STACK_*),
// Vendor Data (only when aVendorData is non-NULL) and Provisioning URL (only
// when its encoded length is non-zero). The message offset is set to the start
// of the payload. Returns OT_ERROR_NO_BUFS if no message could be allocated, or
// the first append error; on any error the partially built message is released.
otError Joiner::PrepareJoinerFinalizeMessage(const char *aProvisioningUrl,
                                             const char *aVendorName,
                                             const char *aVendorModel,
                                             const char *aVendorSwVersion,
                                             const char *aVendorData)
{
    otError               error = OT_ERROR_NONE;
    VendorNameTlv         nameTlv;
    VendorModelTlv        modelTlv;
    VendorSwVersionTlv    swVersionTlv;
    VendorStackVersionTlv stackVersionTlv;
    ProvisioningUrlTlv    urlTlv;

    mFinalizeMessage = NewMeshCoPMessage(Get<Coap::CoapSecure>());
    VerifyOrExit(mFinalizeMessage != NULL, error = OT_ERROR_NO_BUFS);

    // CoAP header and URI path, then the payload marker.
    mFinalizeMessage->Init(OT_COAP_TYPE_CONFIRMABLE, OT_COAP_CODE_POST);
    SuccessOrExit(error = mFinalizeMessage->AppendUriPathOptions(OT_URI_PATH_JOINER_FINALIZE));
    SuccessOrExit(error = mFinalizeMessage->SetPayloadMarker());
    mFinalizeMessage->SetOffset(mFinalizeMessage->GetLength());

    SuccessOrExit(error = Tlv::AppendUint8Tlv(*mFinalizeMessage, Tlv::kState, StateTlv::kAccept));

    nameTlv.Init();
    nameTlv.SetVendorName(aVendorName);
    SuccessOrExit(error = nameTlv.AppendTo(*mFinalizeMessage));

    modelTlv.Init();
    modelTlv.SetVendorModel(aVendorModel);
    SuccessOrExit(error = modelTlv.AppendTo(*mFinalizeMessage));

    swVersionTlv.Init();
    swVersionTlv.SetVendorSwVersion(aVendorSwVersion);
    SuccessOrExit(error = swVersionTlv.AppendTo(*mFinalizeMessage));

    stackVersionTlv.Init();
    stackVersionTlv.SetOui(OPENTHREAD_CONFIG_STACK_VENDOR_OUI);
    stackVersionTlv.SetMajor(OPENTHREAD_CONFIG_STACK_VERSION_MAJOR);
    stackVersionTlv.SetMinor(OPENTHREAD_CONFIG_STACK_VERSION_MINOR);
    stackVersionTlv.SetRevision(OPENTHREAD_CONFIG_STACK_VERSION_REV);
    SuccessOrExit(error = stackVersionTlv.AppendTo(*mFinalizeMessage));

    if (aVendorData != NULL)
    {
        VendorDataTlv dataTlv;

        dataTlv.Init();
        dataTlv.SetVendorData(aVendorData);
        SuccessOrExit(error = dataTlv.AppendTo(*mFinalizeMessage));
    }

    urlTlv.Init();
    urlTlv.SetProvisioningUrl(aProvisioningUrl);

    if (urlTlv.GetLength() > 0)
    {
        SuccessOrExit(error = urlTlv.AppendTo(*mFinalizeMessage));
    }

exit:
    if (error != OT_ERROR_NONE)
    {
        FreeJoinerFinalizeMessage();
    }

    return error;
}
// Releases the prepared JOIN_FIN.req, but only when the joiner is IDLE and a
// message is actually held; in any other state the message is left alone.
void Joiner::FreeJoinerFinalizeMessage(void)
{
    if ((mState == OT_JOINER_STATE_IDLE) && (mFinalizeMessage != NULL))
    {
        mFinalizeMessage->Free();
        mFinalizeMessage = NULL;
    }
}
void Joiner::SendJoinerFinalize(void)
{
OT_ASSERT(mFinalizeMessage != NULL);
#if OPENTHREAD_CONFIG_REFERENCE_DEVICE_ENABLE
LogCertMessage("[THCI] direction=send | type=JOIN_FIN.req |", *mFinalizeMessage);
#endif
SuccessOrExit(Get<Coap::CoapSecure>().SendMessage(*mFinalizeMessage, Joiner::HandleJoinerFinalizeResponse, this));
mFinalizeMessage = NULL;
otLogInfoMeshCoP("Joiner sent finalize");
exit:
return;
}
void Joiner::HandleJoinerFinalizeResponse(void * aContext,
otMessage * aMessage,
const otMessageInfo *aMessageInfo,
otError aResult)
{
static_cast<Joiner *>(aContext)->HandleJoinerFinalizeResponse(
*static_cast<Coap::Message *>(aMessage), static_cast<const Ip6::MessageInfo *>(aMessageInfo), aResult);
}
// Processes the JOIN_FIN.rsp.
//
// When the joiner is CONNECTED, the exchange succeeded, and the reply is an ACK
// with code 2.04 Changed carrying a State TLV, it moves to ENTRUST and re-arms
// the response timer to wait for JOIN_ENT.ntf. The State TLV value is only
// logged. Whatever the outcome, the secure session is disconnected and the
// joiner UDP port is closed.
void Joiner::HandleJoinerFinalizeResponse(Coap::Message &         aMessage,
                                          const Ip6::MessageInfo *aMessageInfo,
                                          otError                 aResult)
{
    bool    isExpectedResponse;
    uint8_t stateValue;

    OT_UNUSED_VARIABLE(aMessageInfo);

    isExpectedResponse = (mState == OT_JOINER_STATE_CONNECTED) && (aResult == OT_ERROR_NONE) && aMessage.IsAck() &&
                         (aMessage.GetCode() == OT_COAP_CODE_CHANGED);

    if (isExpectedResponse && Tlv::FindUint8Tlv(aMessage, Tlv::kState, stateValue) == OT_ERROR_NONE)
    {
        SetState(OT_JOINER_STATE_ENTRUST);
        mTimer.Start(kReponseTimeout);

        otLogInfoMeshCoP("Joiner received finalize response %d", stateValue);

#if OPENTHREAD_CONFIG_REFERENCE_DEVICE_ENABLE
        LogCertMessage("[THCI] direction=recv | type=JOIN_FIN.rsp |", aMessage);
#endif
    }

    Get<Coap::CoapSecure>().Disconnect();
    IgnoreError(Get<Ip6::Filter>().RemoveUnsecurePort(kJoinerUdpPort));
}
void Joiner::HandleJoinerEntrust(void *aContext, otMessage *aMessage, const otMessageInfo *aMessageInfo)
{
static_cast<Joiner *>(aContext)->HandleJoinerEntrust(*static_cast<Coap::Message *>(aMessage),
*static_cast<const Ip6::MessageInfo *>(aMessageInfo));
}
// Handles JOIN_ENT.ntf.
//
// Accepted only in the ENTRUST state as a confirmable POST. The Network Master
// Key TLV is stored in an Active Dataset together with the current MAC channel
// and PAN ID; the save result is ignored. The joiner then sends the entrust
// response and arms the timer with kConfigExtAddressDelay. Failures (drop or
// missing key) are logged and otherwise ignored.
void Joiner::HandleJoinerEntrust(Coap::Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
{
    otOperationalDataset entrustDataset;
    bool                 isValidRequest;
    otError              error;

    isValidRequest = (mState == OT_JOINER_STATE_ENTRUST) && aMessage.IsConfirmable() &&
                     (aMessage.GetCode() == OT_COAP_CODE_POST);
    VerifyOrExit(isValidRequest, error = OT_ERROR_DROP);

    otLogInfoMeshCoP("Joiner received entrust");
    otLogCertMeshCoP("[THCI] direction=recv | type=JOIN_ENT.ntf");

    memset(&entrustDataset, 0, sizeof(entrustDataset));

    SuccessOrExit(
        error = Tlv::FindTlv(aMessage, Tlv::kNetworkMasterKey, &entrustDataset.mMasterKey, sizeof(MasterKey)));
    entrustDataset.mComponents.mIsMasterKeyPresent = true;

    entrustDataset.mChannel                     = Get<Mac::Mac>().GetPanChannel();
    entrustDataset.mComponents.mIsChannelPresent = true;

    entrustDataset.mPanId                       = Get<Mac::Mac>().GetPanId();
    entrustDataset.mComponents.mIsPanIdPresent  = true;

    IgnoreError(Get<MeshCoP::ActiveDataset>().Save(entrustDataset));

    otLogInfoMeshCoP("Joiner successful!");

    SendJoinerEntrustResponse(aMessage, aMessageInfo);
    mTimer.Start(kConfigExtAddressDelay);

exit:
    if (error != OT_ERROR_NONE)
    {
        otLogWarnMeshCoP("Failed to process joiner entrust: %s", otThreadErrorToString(error));
    }
}
// Sends the default CoAP response for the entrust request aRequest, tagged with
// sub-type kSubTypeJoinerEntrust, over the plain Coap::Coap agent. The request's
// socket address is cleared in the reply info. Moves to JOINED only if the send
// succeeds; the response message is freed on any failure after allocation.
void Joiner::SendJoinerEntrustResponse(const Coap::Message &aRequest, const Ip6::MessageInfo &aRequestInfo)
{
    Ip6::MessageInfo replyInfo(aRequestInfo);
    Coap::Message *  reply = NULL;
    otError          error = OT_ERROR_NONE;

    reply = NewMeshCoPMessage(Get<Coap::Coap>());
    VerifyOrExit(reply != NULL, error = OT_ERROR_NO_BUFS);

    SuccessOrExit(error = reply->SetDefaultResponseHeader(aRequest));
    reply->SetSubType(Message::kSubTypeJoinerEntrust);

    replyInfo.GetSockAddr().Clear();
    SuccessOrExit(error = Get<Coap::Coap>().SendMessage(*reply, replyInfo));

    SetState(OT_JOINER_STATE_JOINED);

    otLogInfoMeshCoP("Joiner sent entrust response");
    otLogCertMeshCoP("[THCI] direction=send | type=JOIN_ENT.rsp");

exit:
    if ((error != OT_ERROR_NONE) && (reply != NULL))
    {
        reply->Free();
    }
}
// Timer trampoline: forwards to the Joiner that owns aTimer.
void Joiner::HandleTimer(Timer &aTimer)
{
    Joiner &joiner = aTimer.GetOwner<Joiner>();

    joiner.HandleTimer();
}
void Joiner::HandleTimer(void)
{
otError error = OT_ERROR_NONE;
switch (mState)
{
case OT_JOINER_STATE_IDLE:
case OT_JOINER_STATE_DISCOVER:
case OT_JOINER_STATE_CONNECT:
OT_ASSERT(false);
OT_UNREACHABLE_CODE(break);
case OT_JOINER_STATE_CONNECTED:
case OT_JOINER_STATE_ENTRUST:
error = OT_ERROR_RESPONSE_TIMEOUT;
break;
case OT_JOINER_STATE_JOINED:
Mac::ExtAddress extAddress;
extAddress.GenerateRandom();
Get<Mac::Mac>().SetExtAddress(extAddress);
Get<Mle::MleRouter>().UpdateLinkLocalAddress();
error = OT_ERROR_NONE;
break;
}
Finish(error);
}
// Returns a human-readable name for aState, or "Unknown" for any value that is
// not one of the six joiner states.
const char *Joiner::JoinerStateToString(otJoinerState aState)
{
    switch (aState)
    {
    case OT_JOINER_STATE_IDLE:
        return "Idle";
    case OT_JOINER_STATE_DISCOVER:
        return "Discover";
    case OT_JOINER_STATE_CONNECT:
        return "Connecting";
    case OT_JOINER_STATE_CONNECTED:
        return "Connected";
    case OT_JOINER_STATE_ENTRUST:
        return "Entrust";
    case OT_JOINER_STATE_JOINED:
        return "Joined";
    }

    return "Unknown";
}
#if OPENTHREAD_CONFIG_REFERENCE_DEVICE_ENABLE
// Reference-device only: dumps the message payload (bytes from the message
// offset to its end) with the prefix aText. Messages longer than the local
// stack buffer (OPENTHREAD_CONFIG_MESSAGE_BUFFER_SIZE bytes) are skipped.
void Joiner::LogCertMessage(const char *aText, const Coap::Message &aMessage) const
{
    uint8_t payload[OPENTHREAD_CONFIG_MESSAGE_BUFFER_SIZE];

    if (aMessage.GetLength() <= sizeof(payload))
    {
        aMessage.Read(aMessage.GetOffset(), aMessage.GetLength() - aMessage.GetOffset(), payload);
        otDumpCertMeshCoP(aText, payload, aMessage.GetLength() - aMessage.GetOffset());
    }
}
#endif
} }
#endif